diff --git a/talkingface-toolkit-main/README.md b/talkingface-toolkit-main/README.md
new file mode 100644
index 00000000..b6f46dbb
--- /dev/null
+++ b/talkingface-toolkit-main/README.md
@@ -0,0 +1,216 @@
+Example commands for the MetaPortrait base model (inference, and warp-stage pretraining):
+
+python meta_portrait_base_inference.py --save_dir ./saved --config ./talkingface/properties/model/meta_portrait_base_config/meta_portrait_256_eval.yaml --ckpt ./saved/ckpt_base.pth.tar
+
+python meta_portrait_base_main.py --config ./talkingface/properties/model/meta_portrait_base_config/meta_portrait_256_pretrain_warp.yaml --fp16 --stage Warp --task Pretrain
+
+
+
+# talkingface-toolkit
+## Framework overview
+### checkpoints
+Stores the extra pre-trained weights needed to train and evaluate models; see the corresponding [README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/checkpoints/README.md) for details.
+
+### dataset
+Stores the datasets and their preprocessed data; see the [README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/dataset/README.md) inside `dataset` for details.
+
+### saved
+Stores the model checkpoints written during training; created automatically when a checkpoint is saved.
+
+### talkingface
+The main package, containing all core code.
+
+#### config
+Automatically assembles all model, dataset, training, and evaluation configuration from the given model and dataset names.
+```
+config/
+├── configurator.py
+```
+#### data
+- dataprocess: model-specific data processing code (e.g. the audio feature extraction or inference-time preprocessing implemented by the upstream repository). Create a corresponding file here if your model needs it.
+- dataset: every model must subclass `torch.utils.data.Dataset` to load its data, in a file named `model_name+'_dataset.py'`. The return value of `__getitem__()` must be a dict. (Core part.)
+```
+data/
+├── dataprocess
+│   ├── wav2lip_process.py
+│   ├── xxxx_process.py
+├── dataset
+│   ├── wav2lip_dataset.py
+│   ├── xxx_dataset.py
+```
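
As a rough sketch of the dataset convention above (all names here are illustrative; in the toolkit the class would subclass `torch.utils.data.Dataset`, which shares the same `__len__`/`__getitem__` protocol), a hypothetical `xxx_dataset.py` could look like:

```python
# Sketch of a hypothetical xxx_dataset.py. A plain class is used instead of
# torch.utils.data.Dataset to keep the example self-contained; the protocol
# (__len__ / __getitem__) is identical.
class XxxDataset:
    def __init__(self, filelist):
        # filelist: one sample identifier per entry (e.g. lines of train.txt)
        self.samples = filelist

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path = self.samples[idx]
        # Real code would load frames/audio here. The return value must be a
        # dict, because downstream code indexes batches by key.
        return {"video_path": path, "frames": [], "mel": []}

dataset = XxxDataset(["00001/video", "00002/video"])
print(sorted(dataset[0].keys()))  # → ['frames', 'mel', 'video_path']
```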
+
+#### evaluate
+Model evaluation code.
+The LSE metric takes a list of generated videos.
+The SSIM metric takes the lists of generated and ground-truth videos.
+
+#### model
+The network definitions and related methods of the implemented models. (Core part.)
+
+Models fall into three categories:
+- audio-driven
+- image-driven
+- nerf-based (neural radiance fields)
+
+```
+model/
+├── audio_driven_talkingface
+│   ├── wav2lip.py
+├── image_driven_talkingface
+│   ├── xxxx.py
+├── nerf_based_talkingface
+│   ├── xxxx.py
+├── abstract_talkingface.py
+```
+
+#### properties
+Default configuration files, including:
+- dataset configs
+- model configs
+- the global config
+
+Add a config file for each new model and dataset; the global config `overall.yaml` normally stays unchanged.
+```
+properties/
+├── dataset
+│   ├── xxx.yaml
+├── model
+│   ├── xxx.yaml
+├── overall.yaml
+```
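
For illustration only, a hypothetical model config such as `properties/model/xxx.yaml` might carry entries like the following (the keys shown are assumptions, not the toolkit's actual schema):

```yaml
# Illustrative keys only; the real set of options depends on the model.
learning_rate: 0.0001
train_batch_size: 16
epochs: 50
```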
+
+#### quick_start
+The generic entry point: it configures the dataset and model from the given arguments, then runs training and evaluation (normally needs no changes).
+```
+quick_start/
+├── quick_start.py
+```
+
+#### trainer
+The main training and evaluation class. If the base `Trainer` already covers everything your model needs, no new class is required. If training has model-specific parts, subclass `Trainer`; the overrides usually concentrate in `_train_epoch()` and `_valid_epoch()`. A subclassed trainer should be named `{model_name}Trainer`.
+```
+trainer/
+├── trainer.py
+```
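
A minimal sketch of the override pattern described above (the base class here is a stand-in for the toolkit's real `Trainer`, and the subclass name merely follows the `{model_name}Trainer` convention):

```python
# Stand-in for the toolkit's base Trainer, reduced to the two hooks that
# model-specific subclasses usually override.
class Trainer:
    def __init__(self, config):
        self.config = config

    def _train_epoch(self, epoch):
        raise NotImplementedError

    def _valid_epoch(self, epoch):
        raise NotImplementedError

    def fit(self, epochs):
        # Run training and validation once per epoch, collecting results.
        history = []
        for epoch in range(epochs):
            history.append((self._train_epoch(epoch), self._valid_epoch(epoch)))
        return history

# Hypothetical model-specific trainer following the naming convention.
class Wav2lipTrainer(Trainer):
    def _train_epoch(self, epoch):
        # Model-specific training logic would go here.
        return {"epoch": epoch, "loss": 0.0}

    def _valid_epoch(self, epoch):
        return {"epoch": epoch, "sync_loss": 0.0}

trainer = Wav2lipTrainer(config={})
print(len(trainer.fit(2)))  # → 2
```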
+
+#### utils
+Shared utilities, including `s3fd` face detection and helpers for extracting frames and audio from videos, plus the lookup functions that map a configuration to the corresponding model and dataset classes.
+These normally need no changes, but necessary and reasonably generic data-processing helpers may be added.
+
+## Usage
+### Requirements
+- `python=3.8`
+- `torch==1.13.1+cu116` (GPU build; use the CPU build if your device does not support CUDA)
+- `numpy==1.20.3`
+- `librosa==0.10.1`
+
+Try to match these package versions exactly.
+
+Two ways are provided to install the remaining dependencies:
+```
+pip install -r requirements.txt
+
+or
+
+conda env create -f environment.yml
+```
+
+Using a conda virtual environment is strongly recommended!
+
+### Training and evaluation
+
+```bash
+python run_talkingface.py --model=xxxx --dataset=xxxx (--other_parameters=xxxxxx)
+```
+
+### Pre-trained weights
+
+- Weights for LSE evaluation: syncnet_v2.model ([Baidu Netdisk download](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc))
+- Lip expert weights for wav2lip: lipsync_expert.pth ([Baidu Netdisk download](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc))
+
+## Candidate papers
+### Audio_driven talkingface
+| Model | Paper | Code |
+|:--------:|:--------:|:--------:|
+| MakeItTalk | [paper](https://arxiv.org/abs/2004.12992) | [code](https://github.com/yzhou359/MakeItTalk) |
+| MEAD | [paper](https://wywu.github.io/projects/MEAD/support/MEAD.pdf) | [code](https://github.com/uniBruce/Mead) |
+| RhythmicHead | [paper](https://arxiv.org/pdf/2007.08547v1.pdf) | [code](https://github.com/lelechen63/Talking-head-Generation-with-Rhythmic-Head-Motion) |
+| PC-AVS | [paper](https://arxiv.org/abs/2104.11116) | [code](https://github.com/Hangz-nju-cuhk/Talking-Face_PC-AVS) |
+| EVP | [paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Ji_Audio-Driven_Emotional_Video_Portraits_CVPR_2021_paper.pdf) | [code](https://github.com/jixinya/EVP) |
+| LSP | [paper](https://arxiv.org/abs/2109.10595) | [code](https://github.com/YuanxunLu/LiveSpeechPortraits) |
+| EAMM | [paper](https://arxiv.org/pdf/2205.15278.pdf) | [code](https://github.com/jixinya/EAMM/) |
+| DiffTalk | [paper](https://arxiv.org/abs/2301.03786) | [code](https://github.com/sstzal/DiffTalk) |
+| TalkLip | [paper](https://arxiv.org/pdf/2303.17480.pdf) | [code](https://github.com/Sxjdwang/TalkLip) |
+| EmoGen | [paper](https://arxiv.org/pdf/2303.11548.pdf) | [code](https://github.com/sahilg06/EmoGen) |
+| SadTalker | [paper](https://arxiv.org/abs/2211.12194) | [code](https://github.com/OpenTalker/SadTalker) |
+| HyperLips | [paper](https://arxiv.org/abs/2310.05720) | [code](https://github.com/semchan/HyperLips) |
+| PHADTF | [paper](http://arxiv.org/abs/2002.10137) | [code](https://github.com/yiranran/Audio-driven-TalkingFace-HeadPose) |
+| VideoReTalking | [paper](https://arxiv.org/abs/2211.14758) | [code](https://github.com/OpenTalker/video-retalking#videoretalking--audio-based-lip-synchronization-for-talking-head-video-editing-in-the-wild-) |
+
+
+
+### Image_driven talkingface
+| Model | Paper | Code |
+|:--------:|:--------:|:--------:|
+| PIRenderer | [paper](https://arxiv.org/pdf/2109.08379.pdf) | [code](https://github.com/RenYurui/PIRender) |
+| StyleHEAT | [paper](https://arxiv.org/pdf/2203.04036.pdf) | [code](https://github.com/OpenTalker/StyleHEAT) |
+| MetaPortrait | [paper](https://arxiv.org/abs/2212.08062) | [code](https://github.com/Meta-Portrait/MetaPortrait) |
+
+### Nerf-based talkingface
+| Model | Paper | Code |
+|:--------:|:--------:|:--------:|
+| AD-NeRF | [paper](https://arxiv.org/abs/2103.11078) | [code](https://github.com/YudongGuo/AD-NeRF) |
+| GeneFace | [paper](https://arxiv.org/abs/2301.13430) | [code](https://github.com/yerfor/GeneFace) |
+| DFRF | [paper](https://arxiv.org/abs/2207.11770) | [code](https://github.com/sstzal/DFRF) |
+
+### text_to_speech
+| Model | Paper | Code |
+|:--------:|:--------:|:--------:|
+| VITS | [paper](https://arxiv.org/abs/2106.06103) | [code](https://github.com/jaywalnut310/vits) |
+| Glow TTS | [paper](https://arxiv.org/abs/2005.11129) | [code](https://github.com/jaywalnut310/glow-tts) |
+| FastSpeech2 | [paper](https://arxiv.org/abs/2006.04558v1) | [code](https://github.com/ming024/FastSpeech2) |
+| StyleTTS2 | [paper](https://arxiv.org/abs/2306.07691) | [code](https://github.com/yl4579/StyleTTS2) |
+| Grad-TTS | [paper](https://arxiv.org/abs/2105.06337) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS) |
+| FastSpeech | [paper](https://arxiv.org/abs/1905.09263) | [code](https://github.com/xcmyz/FastSpeech) |
+
+### voice_conversion
+| Model | Paper | Code |
+|:--------:|:--------:|:--------:|
+| StarGAN-VC | [paper](http://www.kecl.ntt.co.jp/people/kameoka.hirokazu/Demos/stargan-vc2/index.html) | [code](https://github.com/kamepong/StarGAN-VC) |
+| Emo-StarGAN | [paper](https://www.researchgate.net/publication/373161292_Emo-StarGAN_A_Semi-Supervised_Any-to-Many_Non-Parallel_Emotion-Preserving_Voice_Conversion) | [code](https://github.com/suhitaghosh10/emo-stargan) |
+| adaptive-VC | [paper](https://arxiv.org/abs/1904.05742) | [code](https://github.com/jjery2243542/adaptive_voice_conversion) |
+| DiffVC | [paper](https://arxiv.org/abs/2109.13821) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC) |
+| Assem-VC | [paper](https://arxiv.org/abs/2104.00931) | [code](https://github.com/maum-ai/assem-vc) |
+
+
+## Assignment requirements
+- Make sure training and validation can be launched with nothing more than the model and dataset names on the command line. (Models whose upstream repositories provide no training code may skip training.)
+- Each group must submit a README describing the implemented features, screenshots of the final training and validation runs, the dependencies used, and the division of work among members.
+
+
+
diff --git a/talkingface-toolkit-main/checkpoints/README.md b/talkingface-toolkit-main/checkpoints/README.md
new file mode 100644
index 00000000..0a1432d6
--- /dev/null
+++ b/talkingface-toolkit-main/checkpoints/README.md
@@ -0,0 +1,16 @@
+This folder stores the extra pre-trained weights used during model training or validation, e.g.:
+- the syncnet weights used by wav2lip
+- the syncnet-v2 weights used to compute the LSE lip-audio sync score of synthesized videos
+- .......
+
+Directory layout:
+```
+checkpoints/
+├── LSE
+│   ├── syncnet_v2.model
+├── Wav2Lip
+│   ├── lipsync_expert.pth
+```
diff --git a/talkingface-toolkit-main/dataset/README.md b/talkingface-toolkit-main/dataset/README.md
new file mode 100644
index 00000000..8a2201ad
--- /dev/null
+++ b/talkingface-toolkit-main/dataset/README.md
@@ -0,0 +1,41 @@
+This folder stores the datasets, e.g.:
+- lrw
+- lrs2
+- mead
+- .......
+
+The general layout of a processed dataset is:
+
+```
+dataset/
+├── lrs2
+│   ├── data (raw dataset files)
+│   ├── filelist (the dataset splits)
+│   │   ├── train.txt
+│   │   ├── val.txt
+│   │   ├── test.txt
+│   ├── preprocessed_data (mainly the frames extracted from each video and the corresponding audio files; for the exact paths see how lrs2 is handled in talkingface.utils.data_preprocess)
+```
+
+The paths under preprocessed_data are generally laid out as:
+```
+preprocessed_root (lrs2_preprocessed)/
+├── list of folders
+│   ├── Folders with five-digit numbered video IDs
+│   │   ├── *.jpg
+│   │   ├── audio.wav
+```
+
+Keep dataset storage to this layout as far as possible, and keep the splits in the train.txt, val.txt, and test.txt files.
diff --git a/talkingface-toolkit-main/environment.yml b/talkingface-toolkit-main/environment.yml
new file mode 100644
index 00000000..09a8595b
--- /dev/null
+++ b/talkingface-toolkit-main/environment.yml
@@ -0,0 +1,138 @@
+name: torch38
+channels:
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - ca-certificates=2023.08.22=h06a4308_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - libffi=3.4.4=h6a678d5_0
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - ncurses=6.4=h6a678d5_0
+ - openssl=3.0.12=h7f8727e_0
+ - pip=23.3=py38h06a4308_0
+ - python=3.8.18=h955ad1f_0
+ - readline=8.2=h5eee18b_0
+ - setuptools=68.0.0=py38h06a4308_0
+ - sqlite=3.41.2=h5eee18b_0
+ - tk=8.6.12=h1ccaba5_0
+ - wheel=0.41.2=py38h06a4308_0
+ - xz=5.4.2=h5eee18b_0
+ - zlib=1.2.13=h5eee18b_0
+ - pip:
+ - absl-py==2.0.0
+ - addict==2.4.0
+ - aiosignal==1.3.1
+ - appdirs==1.4.4
+ - attrs==23.1.0
+ - audioread==3.0.1
+ - basicsr==1.3.4.7
+ - cachetools==5.3.2
+ - certifi==2020.12.5
+ - cffi==1.16.0
+ - charset-normalizer==3.3.2
+ - click==8.1.7
+ - cloudpickle==3.0.0
+ - colorama==0.4.6
+ - colorlog==6.7.0
+ - contourpy==1.1.1
+ - cycler==0.12.1
+ - decorator==5.1.1
+ - dlib==19.22.1
+ - docker-pycreds==0.4.0
+ - face-alignment==1.3.5
+ - ffmpeg==1.4
+ - filelock==3.13.1
+ - fonttools==4.44.0
+ - frozenlist==1.4.0
+ - future==0.18.3
+ - gitdb==4.0.11
+ - gitpython==3.1.40
+ - glob2==0.7
+ - google-auth==2.23.4
+ - google-auth-oauthlib==0.4.6
+ - grpcio==1.59.2
+ - hyperopt==0.2.5
+ - idna==3.4
+ - imageio==2.9.0
+ - imageio-ffmpeg==0.4.5
+ - importlib-metadata==6.8.0
+ - importlib-resources==6.1.0
+ - joblib==1.3.2
+ - jsonschema==4.19.2
+ - jsonschema-specifications==2023.7.1
+ - kiwisolver==1.4.5
+ - kornia==0.5.5
+ - lazy-loader==0.3
+ - librosa==0.10.1
+ - llvmlite==0.37.0
+ - lmdb==1.2.1
+ - lws==1.2.7
+ - markdown==3.5.1
+ - markupsafe==2.1.3
+ - matplotlib==3.6.3
+ - msgpack==1.0.7
+ - networkx==3.1
+ - numba==0.54.1
+ - numpy==1.20.3
+ - oauthlib==3.2.2
+ - opencv-python==3.4.9.33
+ - packaging==23.2
+ - pandas==1.3.4
+ - pathtools==0.1.2
+ - pillow==6.2.1
+ - pkgutil-resolve-name==1.3.10
+ - platformdirs==3.11.0
+ - plotly==5.18.0
+ - pooch==1.8.0
+ - protobuf==4.25.0
+ - psutil==5.9.6
+ - pyasn1==0.5.0
+ - pyasn1-modules==0.3.0
+ - pycparser==2.21
+ - pyparsing==3.1.1
+ - python-dateutil==2.8.2
+ - python-speech-features==0.6
+ - pytorch-fid==0.3.0
+ - pytz==2023.3.post1
+ - pywavelets==1.4.1
+ - pyyaml==5.3.1
+ - ray==2.6.3
+ - referencing==0.30.2
+ - requests==2.31.0
+ - requests-oauthlib==1.3.1
+ - rpds-py==0.12.0
+ - rsa==4.9
+ - scikit-image==0.16.2
+ - scikit-learn==1.3.2
+ - scipy==1.5.0
+ - sentry-sdk==1.34.0
+ - setproctitle==1.3.3
+ - six==1.16.0
+ - smmap==5.0.1
+ - soundfile==0.12.1
+ - soxr==0.3.7
+ - tabulate==0.9.0
+ - tb-nightly==2.12.0a20230126
+ - tenacity==8.2.3
+ - tensorboard==2.7.0
+ - tensorboard-data-server==0.6.1
+ - tensorboard-plugin-wit==1.8.1
+ - texttable==1.7.0
+ - thop==0.1.1-2209072238
+ - threadpoolctl==3.2.0
+ - tomli==2.0.1
+ - torch==1.13.1+cu116
+ - torchaudio==0.13.1+cu116
+ - torchvision==0.14.1+cu116
+ - tqdm==4.66.1
+ - trimesh==3.9.20
+ - typing-extensions==4.8.0
+ - tzdata==2023.3
+ - urllib3==2.0.7
+ - wandb==0.15.12
+ - werkzeug==3.0.1
+ - yapf==0.40.2
+ - zipp==3.17.0
diff --git a/talkingface-toolkit-main/requirements.txt b/talkingface-toolkit-main/requirements.txt
new file mode 100644
index 00000000..1605c1fe
--- /dev/null
+++ b/talkingface-toolkit-main/requirements.txt
@@ -0,0 +1,114 @@
+absl-py==2.0.0
+addict==2.4.0
+aiosignal==1.3.1
+appdirs==1.4.4
+attrs==23.1.0
+audioread==3.0.1
+basicsr==1.3.4.7
+cachetools==5.3.2
+certifi==2020.12.5
+cffi==1.16.0
+charset-normalizer==3.3.2
+click==8.1.7
+cloudpickle==3.0.0
+colorama==0.4.6
+colorlog==6.7.0
+contourpy==1.1.1
+cycler==0.12.1
+decorator==5.1.1
+dlib==19.22.1
+docker-pycreds==0.4.0
+face-alignment==1.3.5
+ffmpeg==1.4
+filelock==3.13.1
+fonttools==4.44.0
+frozenlist==1.4.0
+future==0.18.3
+gitdb==4.0.11
+GitPython==3.1.40
+glob2==0.7
+google-auth==2.23.4
+google-auth-oauthlib==0.4.6
+grpcio==1.59.2
+hyperopt==0.2.5
+idna==3.4
+imageio==2.9.0
+imageio-ffmpeg==0.4.5
+importlib-metadata==6.8.0
+importlib-resources==6.1.0
+joblib==1.3.2
+jsonschema==4.19.2
+jsonschema-specifications==2023.7.1
+kiwisolver==1.4.5
+kornia==0.5.5
+lazy_loader==0.3
+librosa==0.10.1
+llvmlite==0.37.0
+lmdb==1.2.1
+lws==1.2.7
+Markdown==3.5.1
+MarkupSafe==2.1.3
+matplotlib==3.6.3
+msgpack==1.0.7
+networkx==3.1
+numba==0.54.1
+numpy==1.20.3
+oauthlib==3.2.2
+opencv-python==3.4.9.33
+packaging==23.2
+pandas==1.3.4
+pathtools==0.1.2
+Pillow==6.2.1
+pkgutil_resolve_name==1.3.10
+platformdirs==3.11.0
+plotly==5.18.0
+pooch==1.8.0
+protobuf==4.25.0
+psutil==5.9.6
+pyasn1==0.5.0
+pyasn1-modules==0.3.0
+pycparser==2.21
+pyparsing==3.1.1
+python-dateutil==2.8.2
+python-speech-features==0.6
+pytorch-fid==0.3.0
+pytz==2023.3.post1
+PyWavelets==1.4.1
+PyYAML==5.3.1
+ray==2.6.3
+referencing==0.30.2
+requests==2.31.0
+requests-oauthlib==1.3.1
+rpds-py==0.12.0
+rsa==4.9
+scikit-image==0.16.2
+scikit-learn==1.3.2
+scipy==1.5.0
+sentry-sdk==1.34.0
+setproctitle==1.3.3
+six==1.16.0
+smmap==5.0.1
+soundfile==0.12.1
+soxr==0.3.7
+tabulate==0.9.0
+tb-nightly==2.12.0a20230126
+tenacity==8.2.3
+tensorboard==2.7.0
+tensorboard-data-server==0.6.1
+tensorboard-plugin-wit==1.8.1
+texttable==1.7.0
+thop==0.1.1.post2209072238
+threadpoolctl==3.2.0
+tomli==2.0.1
+torch==1.13.1+cu116
+torchaudio==0.13.1+cu116
+torchvision==0.14.1+cu116
+tqdm==4.66.1
+trimesh==3.9.20
+typing_extensions==4.8.0
+tzdata==2023.3
+urllib3==2.0.7
+wandb==0.15.12
+Werkzeug==3.0.1
+yapf==0.40.2
+zipp==3.17.0
diff --git a/talkingface-toolkit-main/run_talkingface.py b/talkingface-toolkit-main/run_talkingface.py
new file mode 100644
index 00000000..3989d566
--- /dev/null
+++ b/talkingface-toolkit-main/run_talkingface.py
@@ -0,0 +1,24 @@
+import argparse
+from talkingface.quick_start import run
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+    parser.add_argument("--model", "-m", type=str, default=None, help="name of models")
+ parser.add_argument(
+ "--dataset", "-d", type=str, default=None, help="name of datasets"
+ )
+ parser.add_argument("--evaluate_model_file", type=str, default=None, help="The model file you want to evaluate")
+ parser.add_argument("--config_files", type=str, default=None, help="config files")
+
+
+ args, _ = parser.parse_known_args()
+
+ config_file_list = (
+ args.config_files.strip().split(" ") if args.config_files else None
+ )
+ run(
+ args.model,
+ args.dataset,
+ config_file_list=config_file_list,
+ evaluate_model_file=args.evaluate_model_file
+ )
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/config/__init__.py b/talkingface-toolkit-main/talkingface/config/__init__.py
new file mode 100644
index 00000000..25fcd4c9
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/config/__init__.py
@@ -0,0 +1 @@
+from talkingface.config.configurator import Config
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/config/configurator.py b/talkingface-toolkit-main/talkingface/config/configurator.py
new file mode 100644
index 00000000..3fad0f4c
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/config/configurator.py
@@ -0,0 +1,335 @@
+import re
+import os
+import sys
+import yaml
+from logging import getLogger
+
+from talkingface.utils import(
+ get_model,
+ # Enum,
+ # ModelType,
+ # InputType,
+ general_arguments,
+ training_arguments,
+ evaluation_arguments,
+ set_color
+)
+
+class Config(object):
+    """Configurator module that loads the defined parameters.
+
+ Configurator module will first load the default parameters from the fixed properties in TalkingFace and then
+ load parameters from the external input.
+
+    External input supports three kinds of forms: config file, command line and parameter dictionaries.
+
+    - config file: It's a file that records the parameters to be modified or added. It should be in ``yaml`` format,
+ e.g. a config file is 'example.yaml', the content is:
+
+ learning_rate: 0.001
+
+ train_batch_size: 2048
+
+    - command line: It should be in the format '--learning_rate=0.001'
+
+ - parameter dictionaries: It should be a dict, where the key is parameter name and the value is parameter value,
+ e.g. config_dict = {'learning_rate': 0.001}
+
+    Configuration module allows the above three kinds of external input to be used together;
+    the priority order is as follows:
+
+ command line > parameter dictionaries > config file
+
+ e.g. If we set learning_rate=0.01 in config file, learning_rate=0.02 in command line,
+ learning_rate=0.03 in parameter dictionaries.
+
+ Finally the learning_rate is equal to 0.02.
+ """
+
+ def __init__(
+ self, model=None, dataset=None, config_file_list=None, config_dict=None
+ ):
+ """
+ Args:
+            model (str or model class): the model name or the model class, default is None; if it is None, config
+ will search the parameter 'model' from the external input as the model name or model class.
+ dataset (str): the dataset name, default is None, if it is None, config will search the parameter 'dataset'
+ from the external input as the dataset name.
+ config_file_list (list of str): the external config file, it allows multiple config files, default is None.
+ config_dict (dict): the external parameter dictionaries, default is None.
+ """
+ self.compatibility_settings()
+ self._init_parameters_category()
+ self.yaml_loader = self._build_yaml_loader()
+ self.file_config_dict = self._load_config_files(config_file_list)
+ self.variable_config_dict = self._load_variable_config_dict(config_dict)
+ self.cmd_config_dict = self._load_cmd_line()
+ self._merge_external_config_dict()
+
+ self.model, self.model_class, self.dataset = self._get_model_and_dataset(
+ model, dataset
+ )
+ self._load_internal_config_dict(self.model, self.model_class, self.dataset)
+ self.final_config_dict = self._get_final_config_dict()
+ self._set_default_parameters()
+ self._init_device()
+
+ def _init_parameters_category(self):
+ self.parameters = dict()
+ self.parameters["General"] = general_arguments
+ self.parameters["Training"] = training_arguments
+ self.parameters["Evaluation"] = evaluation_arguments
+
+ def _build_yaml_loader(self):
+ loader = yaml.FullLoader
+ loader.add_implicit_resolver(
+ "tag:yaml.org,2002:float",
+ re.compile(
+ """^(?:
+ [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
+ |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
+ |\\.[0-9_]+(?:[eE][-+][0-9]+)?
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
+ |[-+]?\\.(?:inf|Inf|INF)
+ |\\.(?:nan|NaN|NAN))$""",
+ re.X,
+ ),
+ list("-+0123456789."),
+ )
+ return loader
+
+ def _convert_config_dict(self, config_dict):
+        """This function converts the str parameters to their original type."""
+ for key in config_dict:
+ param = config_dict[key]
+ if not isinstance(param, str):
+ continue
+ try:
+ value = eval(param)
+ if value is not None and not isinstance(
+ value, (str, int, float, list, tuple, dict, bool)
+ ):
+ value = param
+ except (NameError, SyntaxError, TypeError):
+ if isinstance(param, str):
+ if param.lower() == "true":
+ value = True
+ elif param.lower() == "false":
+ value = False
+ else:
+ value = param
+ else:
+ value = param
+ config_dict[key] = value
+ return config_dict
+
+ def _load_config_files(self, file_list):
+ file_config_dict = dict()
+ if file_list:
+ for file in file_list:
+ with open(file, "r", encoding="utf-8") as f:
+ file_config_dict.update(
+ yaml.load(f.read(), Loader=self.yaml_loader)
+ )
+ return file_config_dict
+
+ def _load_variable_config_dict(self, config_dict):
+ # HyperTuning may set the parameters such as mlp_hidden_size in NeuMF in the format of ['[]', '[]']
+ # then config_dict will receive a str '[]', but indeed it's a list []
+ # temporarily use _convert_config_dict to solve this problem
+ return self._convert_config_dict(config_dict) if config_dict else dict()
+
+ def _load_cmd_line(self):
+ r"""Read parameters from command line and convert it to str."""
+ cmd_config_dict = dict()
+ unrecognized_args = []
+ if "ipykernel_launcher" not in sys.argv[0]:
+ for arg in sys.argv[1:]:
+ if not arg.startswith("--") or len(arg[2:].split("=")) != 2:
+ unrecognized_args.append(arg)
+ continue
+ cmd_arg_name, cmd_arg_value = arg[2:].split("=")
+ if (
+ cmd_arg_name in cmd_config_dict
+ and cmd_arg_value != cmd_config_dict[cmd_arg_name]
+ ):
+ raise SyntaxError(
+                        "There are duplicate command line args '%s' with different values."
+ % arg
+ )
+ else:
+ cmd_config_dict[cmd_arg_name] = cmd_arg_value
+ if len(unrecognized_args) > 0:
+ logger = getLogger()
+ logger.warning(
+                "command line args [{}] will not be used in talkingface".format(
+ " ".join(unrecognized_args)
+ )
+ )
+ cmd_config_dict = self._convert_config_dict(cmd_config_dict)
+ return cmd_config_dict
+
+ def _merge_external_config_dict(self):
+ external_config_dict = dict()
+ external_config_dict.update(self.file_config_dict)
+ external_config_dict.update(self.variable_config_dict)
+ external_config_dict.update(self.cmd_config_dict)
+ self.external_config_dict = external_config_dict
+
+ def _get_model_and_dataset(self, model, dataset):
+ if model is None:
+ try:
+ model = self.external_config_dict["model"]
+ except KeyError:
+ raise KeyError(
+                    "model needs to be specified in at least one of these ways: "
+ "[model variable, config file, config dict, command line] "
+ )
+ if not isinstance(model, str):
+ final_model_class = model
+ final_model = model.__name__
+ else:
+ final_model = model
+ final_model_class = get_model(final_model)
+
+ if dataset is None:
+ try:
+ final_dataset = self.external_config_dict["dataset"]
+ except KeyError:
+ raise KeyError(
+                    "dataset needs to be specified in at least one of these ways: "
+ "[dataset variable, config file, config dict, command line] "
+ )
+ else:
+ final_dataset = dataset
+
+ return final_model, final_model_class, final_dataset
+
+ def _update_internal_config_dict(self, file):
+ with open(file, "r", encoding="utf-8") as f:
+ config_dict = yaml.load(f.read(), Loader=self.yaml_loader)
+ if config_dict is not None:
+ self.internal_config_dict.update(config_dict)
+        return config_dict
+
+    def _load_internal_config_dict(self, model, model_class, dataset):
+ current_path = os.path.dirname(os.path.realpath(__file__))
+ overall_init_file = os.path.join(current_path, "../properties/overall.yaml")
+ model_init_file = os.path.join(
+ current_path, "../properties/model/" + model + ".yaml"
+ )
+ dataset_init_file = os.path.join(
+ current_path, "../properties/dataset/" + dataset + ".yaml"
+ )
+
+
+ self.internal_config_dict = dict()
+ for file in [
+ overall_init_file,
+ model_init_file,
+ dataset_init_file,
+ ]:
+ if os.path.isfile(file):
+ config_dict = self._update_internal_config_dict(file)
+ # if file == dataset_init_file:
+ # self.parameters["Dataset"] += [
+ # key
+ # for key in config_dict.keys()
+ # if key not in self.parameters["Dataset"]
+ # ]
+
+ def _get_final_config_dict(self):
+ final_config_dict = dict()
+ final_config_dict.update(self.internal_config_dict)
+ final_config_dict.update(self.external_config_dict)
+ return final_config_dict
+
+ def _set_default_parameters(self):
+ self.final_config_dict["dataset"] = self.dataset
+ self.final_config_dict["model"] = self.model
+
+ metrics = self.final_config_dict["metrics"]
+ if isinstance(metrics, str):
+ self.final_config_dict["metrics"] = [metrics]
+
+ self.final_config_dict["checkpoint_dir"] = self.final_config_dict["checkpoint_dir"] + self.final_config_dict["checkpoint_sub_dir"]
+
+ self.final_config_dict["temp_dir"] = self.final_config_dict['temp_dir'] + self.final_config_dict['temp_sub_dir']
+
+ def _init_device(self):
+ if isinstance(self.final_config_dict["gpu_id"], tuple):
+ self.final_config_dict["gpu_id"] = ",".join(
+ map(str, list(self.final_config_dict["gpu_id"]))
+ )
+ else:
+ self.final_config_dict["gpu_id"] = str(self.final_config_dict["gpu_id"])
+ gpu_id = self.final_config_dict["gpu_id"]
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
+
+ def __setitem__(self, key, value):
+ if not isinstance(key, str):
+ raise TypeError("index must be a str.")
+ self.final_config_dict[key] = value
+
+ def __getattr__(self, item):
+ if "final_config_dict" not in self.__dict__:
+ raise AttributeError(
+ f"'Config' object has no attribute 'final_config_dict'"
+ )
+ if item in self.final_config_dict:
+ return self.final_config_dict[item]
+ raise AttributeError(f"'Config' object has no attribute '{item}'")
+
+ def __getitem__(self, item):
+ return self.final_config_dict.get(item)
+
+ def __contains__(self, key):
+ if not isinstance(key, str):
+ raise TypeError("index must be a str.")
+ return key in self.final_config_dict
+
+ def __str__(self):
+ args_info = "\n"
+ for category in self.parameters:
+ args_info += set_color(category + " Hyper Parameters:\n", "pink")
+ args_info += "\n".join(
+ [
+ (
+ set_color("{}", "cyan") + " =" + set_color(" {}", "yellow")
+ ).format(arg, value)
+ for arg, value in self.final_config_dict.items()
+ if arg in self.parameters[category]
+ ]
+ )
+ args_info += "\n\n"
+
+ args_info += set_color("Other Hyper Parameters: \n", "pink")
+ args_info += "\n".join(
+ [
+ (set_color("{}", "cyan") + " = " + set_color("{}", "yellow")).format(
+ arg, value
+ )
+ for arg, value in self.final_config_dict.items()
+ if arg
+ not in {_ for args in self.parameters.values() for _ in args}.union(
+ {"model", "dataset", "config_files"}
+ )
+ ]
+ )
+ args_info += "\n\n"
+ return args_info
+
+ def __repr__(self):
+ return self.__str__()
+
+
+ def compatibility_settings(self):
+ import numpy as np
+
+ np.bool = np.bool_
+ np.int = np.int_
+ np.float = np.float_
+ np.complex = np.complex_
+ np.object = np.object_
+ np.str = np.str_
+ np.long = np.int_
+ np.unicode = np.unicode_
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/data/__init__.py b/talkingface-toolkit-main/talkingface/data/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/talkingface-toolkit-main/talkingface/data/dataprocess/ADNeRFmaster_process.py b/talkingface-toolkit-main/talkingface/data/dataprocess/ADNeRFmaster_process.py
new file mode 100644
index 00000000..e69de29b
diff --git a/talkingface-toolkit-main/talkingface/data/dataprocess/__init__.py b/talkingface-toolkit-main/talkingface/data/dataprocess/__init__.py
new file mode 100644
index 00000000..7c7aef87
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/data/dataprocess/__init__.py
@@ -0,0 +1 @@
+from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio, Wav2LipPreprocessForInference
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/data/dataprocess/wav2lip_process.py b/talkingface-toolkit-main/talkingface/data/dataprocess/wav2lip_process.py
new file mode 100644
index 00000000..e2279a42
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/data/dataprocess/wav2lip_process.py
@@ -0,0 +1,253 @@
+import os
+import cv2
+import numpy as np
+import subprocess
+from tqdm import tqdm
+from glob import glob
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from talkingface.utils import face_detection
+import librosa
+import librosa.filters
+from scipy import signal
+from scipy.io import wavfile
+
+class Wav2LipAudio:
+ """This class is used for audio processing of wav2lip
+
+ 这个类提供了从音频到mel谱的方法
+ """
+
+ def __init__(self, config):
+ self.config = config
+ # Conversions
+ self._mel_basis = None
+
+ def load_wav(self, path, sr):
+ return librosa.core.load(path, sr=sr)[0]
+
+ def save_wav(self, wav, path, sr):
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+ #proposed by @dsmiller
+ wavfile.write(path, sr, wav.astype(np.int16))
+
+    def save_wavenet_wav(self, wav, path, sr):
+        # librosa.output.write_wav was removed in librosa >= 0.8 (this repo pins
+        # librosa 0.10.1), so write through soundfile instead.
+        import soundfile as sf
+        sf.write(path, wav, sr)
+
+ def preemphasis(self, wav, k, preemphasize=True):
+ if preemphasize:
+ return signal.lfilter([1, -k], [1], wav)
+ return wav
+
+ def inv_preemphasis(self, wav, k, inv_preemphasize=True):
+ if inv_preemphasize:
+ return signal.lfilter([1], [1, -k], wav)
+ return wav
+
+ def get_hop_size(self):
+ hop_size = self.config['hop_size']
+ if hop_size is None:
+ assert self.config['frame_shift_ms'] is not None
+ hop_size = int(self.config['frame_shift_ms'] / 1000 * self.config['sample_rate'])
+ return hop_size
+
+ def linearspectrogram(self, wav):
+ D = self._stft(self.preemphasis(wav, self.config['preemphasis'], self.config['preemphasize']))
+ S = self._amp_to_db(np.abs(D)) - self.config['ref_level_db']
+
+ if self.config['signal_normalization']:
+ return self._normalize(S)
+ return S
+
+ def melspectrogram(self, wav):
+ D = self._stft(self.preemphasis(wav, self.config['preemphasis'], self.config['preemphasize']))
+ S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.config['ref_level_db']
+
+ if self.config['signal_normalization']:
+ return self._normalize(S)
+ return S
+
+ def _lws_processor(self):
+ import lws
+ return lws.lws(self.config['n_fft'], self.get_hop_size(), fftsize=self.config['win_size'], mode="speech")
+
+ def _stft(self, y):
+ if self.config['use_lws']:
+ return self._lws_processor().stft(y).T
+ else:
+ return librosa.stft(y=y, n_fft=self.config['n_fft'], hop_length=self.get_hop_size(), win_length=self.config['win_size'])
+
+ ##########################################################
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
+ def num_frames(self, length, fsize, fshift):
+ """Compute number of time frames of spectrogram
+ """
+ pad = (fsize - fshift)
+ if length % fshift == 0:
+ M = (length + pad * 2 - fsize) // fshift + 1
+ else:
+ M = (length + pad * 2 - fsize) // fshift + 2
+ return M
+
+
+ def pad_lr(self, x, fsize, fshift):
+ """Compute left and right padding
+ """
+ M = self.num_frames(len(x), fsize, fshift)
+ pad = (fsize - fshift)
+ T = len(x) + 2 * pad
+ r = (M - 1) * fshift + fsize - T
+ return pad, pad + r
+ ##########################################################
+ #Librosa correct padding
+ def librosa_pad_lr(self, x, fsize, fshift):
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+
+    def _linear_to_mel(self, spectrogram):
+        if self._mel_basis is None:
+            self._mel_basis = self._build_mel_basis()
+        return np.dot(self._mel_basis, spectrogram)
+
+ def _build_mel_basis(self):
+ assert self.config['fmax'] <= self.config['sample_rate'] // 2
+ return librosa.filters.mel(sr=self.config['sample_rate'], n_fft=self.config['n_fft'], n_mels=self.config['num_mels'],
+ fmin=self.config['fmin'], fmax=self.config['fmax'])
+
+ def _amp_to_db(self, x):
+ min_level = np.exp(self.config['min_level_db'] / 20 * np.log(10))
+ return 20 * np.log10(np.maximum(min_level, x))
+
+    def _db_to_amp(self, x):
+        return np.power(10.0, (x) * 0.05)
+
+ def _normalize(self, S):
+ if self.config['allow_clipping_in_normalization']:
+ if self.config['symmetric_mels']:
+ return np.clip((2 * self.config['max_abs_value']) * ((S - self.config['min_level_db']) / (-self.config['min_level_db'])) - self.config['max_abs_value'],
+ -self.config['max_abs_value'], self.config['max_abs_value'])
+ else:
+ return np.clip(self.config['max_abs_value'] * ((S - self.config['min_level_db']) / (-self.config['min_level_db'])), 0, self.config['max_abs_value'])
+
+ assert S.max() <= 0 and S.min() - self.config['min_level_db'] >= 0
+ if self.config['symmetric_mels']:
+ return (2 * self.config['max_abs_value']) * ((S - self.config['min_level_db']) / (-self.config['min_level_db'])) - self.config['max_abs_value']
+ else:
+ return self.config['max_abs_value'] * ((S - self.config['min_level_db']) / (-self.config['min_level_db']))
+
+ def _denormalize(self, D):
+ if self.config['allow_clipping_in_normalization']:
+ if self.config['symmetric_mels']:
+ return (((np.clip(D, -self.config['max_abs_value'],
+ self.config['max_abs_value']) + self.config['max_abs_value']) * -self.config['min_level_db'] / (2 * self.config['max_abs_value']))
+ + self.config['min_level_db'])
+ else:
+ return ((np.clip(D, 0, self.config['max_abs_value']) * -self.config['min_level_db'] / self.config['max_abs_value']) + self.config['min_level_db'])
+
+ if self.config['symmetric_mels']:
+ return (((D + self.config['max_abs_value']) * -self.config['min_level_db'] / (2 * self.config['max_abs_value'])) + self.config['min_level_db'])
+ else:
+ return ((D * -self.config['min_level_db'] / self.config['max_abs_value']) + self.config['min_level_db'])
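A quick way to check that `_normalize` and `_denormalize` are inverses (symmetric-mels case, values inside the clip range): the sketch below uses hypothetical hparams `min_level_db = -100` and `max_abs_value = 4` and verifies the round trip is the identity for in-range dB values:

```python
import numpy as np

# Round-trip sketch of the symmetric normalize/denormalize pair above,
# with hypothetical hparams: min_level_db = -100, max_abs_value = 4.
m, A = -100.0, 4.0

def normalize(S):
    return np.clip(2 * A * ((S - m) / -m) - A, -A, A)

def denormalize(D):
    return (np.clip(D, -A, A) + A) * -m / (2 * A) + m

S = np.linspace(-100.0, 0.0, 11)  # dB values inside [min_level_db, 0]
# For in-range inputs the round trip is exact:
assert np.allclose(denormalize(normalize(S)), S)
```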
+
+
+
+class Wav2LipPreprocessForInference:
+ """This class is used for preprocessing of wav2lip inference
+
+ face_detect: detect face in the image
+ datagen: generate data for inference
+
+ """
+ def __init__(self, config):
+ self.config = config
+ self.fa = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
+ device=config['device'])
+
+ def get_smoothened_boxes(self, boxes, T):
+ for i in range(len(boxes)):
+ if i + T > len(boxes):
+ window = boxes[len(boxes) - T:]
+ else:
+ window = boxes[i : i + T]
+ boxes[i] = np.mean(window, axis=0)
+ return boxes
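`get_smoothened_boxes` above is an in-place forward moving average over `T` boxes, falling back to the last `T` boxes near the tail. A tiny numeric check with `T=2` (values are illustrative):

```python
import numpy as np

# In-place forward moving average over T boxes, as in get_smoothened_boxes;
# note later windows see already-smoothed values because boxes is mutated.
boxes = np.array([[0.0], [2.0], [4.0]])
T = 2
for i in range(len(boxes)):
    window = boxes[len(boxes) - T:] if i + T > len(boxes) else boxes[i:i + T]
    boxes[i] = np.mean(window, axis=0)
assert boxes.ravel().tolist() == [1.0, 3.0, 3.5]
```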
+
+ def face_detect(self, images):
+ batch_size = self.config['batch_size']
+ while True:
+ predictions = []
+ try:
+ for i in tqdm(range(0, len(images), batch_size), desc='Running face detection', leave=False):
+ predictions.extend(self.fa.get_detections_for_batch(np.array(images[i:i + batch_size])))
+ except RuntimeError:
+ if batch_size == 1:
+ raise RuntimeError('Image too big to run face detection on GPU. Please use the --resize_factor argument')
+ batch_size //= 2
+ print('Recovering from OOM error; New batch size: {}'.format(batch_size))
+ continue
+ break
+ results = []
+ pady1, pady2, padx1, padx2 = self.config['pads']
+ for rect, image in zip(predictions, images):
+ if rect is None:
+ cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected.
+ raise ValueError('Face not detected! Ensure the video contains a face in all the frames.')
+
+ y1 = max(0, rect[1] - pady1)
+ y2 = min(image.shape[0], rect[3] + pady2)
+ x1 = max(0, rect[0] - padx1)
+ x2 = min(image.shape[1], rect[2] + padx2)
+
+ results.append([x1, y1, x2, y2])
+ boxes = np.array(results)
+ if not self.config['nosmooth']: boxes = self.get_smoothened_boxes(boxes, T=5)
+ results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)]
+
+ return results
+
+
+ def datagen(self, frames, face_det_results, mels):
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+ if self.config['box'][0] == -1:
+ if not self.config['static']:
+ face_det_results = self.face_detect(frames)
+ else:
+ face_det_results = self.face_detect([frames[0]])
+ else:
+ print('Using the specified bounding box instead of face detection...')
+ y1, y2, x1, x2 = self.config['box']
+ face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames]
+
+ for i, m in enumerate(mels):
+ idx = 0 if self.config['static'] else i%len(frames)
+ frame_to_save = frames[idx].copy()
+ face, coords = face_det_results[idx].copy()
+
+ face = cv2.resize(face, (self.config['img_size'], self.config['img_size']))
+
+ img_batch.append(face)
+ mel_batch.append(m)
+ frame_batch.append(frame_to_save)
+ coords_batch.append(coords)
+
+ if len(img_batch) >= self.config['wav2lip_batch_size']:
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+ img_masked = img_batch.copy()
+ img_masked[:, self.config['img_size']//2:] = 0
+
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+ yield img_batch, mel_batch, frame_batch, coords_batch
+ img_batch, mel_batch, frame_batch, coords_batch = [], [], [], []
+
+ if len(img_batch) > 0:
+ img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+
+ img_masked = img_batch.copy()
+ img_masked[:, self.config['img_size']//2:] = 0
+
+ img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+ mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+ yield img_batch, mel_batch, frame_batch, coords_batch
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/data/dataset/AD_NeRF_master_dataset.py b/talkingface-toolkit-main/talkingface/data/dataset/AD_NeRF_master_dataset.py
new file mode 100644
index 00000000..155860b4
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/data/dataset/AD_NeRF_master_dataset.py
@@ -0,0 +1,282 @@
+import cv2
+import numpy as np
+import face_alignment
+from skimage import io
+import torch
+import torch.nn.functional as F
+import json
+import os
+from sklearn.neighbors import NearestNeighbors
+from pathlib import Path
+import argparse
+
+
+def euler2rot(euler_angle):
+ batch_size = euler_angle.shape[0]
+ theta = euler_angle[:, 0].reshape(-1, 1, 1)
+ phi = euler_angle[:, 1].reshape(-1, 1, 1)
+ psi = euler_angle[:, 2].reshape(-1, 1, 1)
+ one = torch.ones((batch_size, 1, 1), dtype=torch.float32,
+ device=euler_angle.device)
+ zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32,
+ device=euler_angle.device)
+ rot_x = torch.cat((
+ torch.cat((one, zero, zero), 1),
+ torch.cat((zero, theta.cos(), theta.sin()), 1),
+ torch.cat((zero, -theta.sin(), theta.cos()), 1),
+ ), 2)
+ rot_y = torch.cat((
+ torch.cat((phi.cos(), zero, -phi.sin()), 1),
+ torch.cat((zero, one, zero), 1),
+ torch.cat((phi.sin(), zero, phi.cos()), 1),
+ ), 2)
+ rot_z = torch.cat((
+ torch.cat((psi.cos(), -psi.sin(), zero), 1),
+ torch.cat((psi.sin(), psi.cos(), zero), 1),
+ torch.cat((zero, zero, one), 1)
+ ), 2)
+ return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
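A sanity check of `euler2rot`'s algebra: each factor is orthonormal, so the product `R = Rx @ (Ry @ Rz)` must satisfy `R @ R.T = I` and `det(R) = 1`. The NumPy sketch below mirrors the column-wise construction above for a single sample (angles are arbitrary test values):

```python
import math
import numpy as np

# Single-sample NumPy re-derivation mirroring euler2rot's column-wise
# construction; any product of rotations must be orthonormal with det 1.
def rot_single(theta, phi, psi):
    c, s = math.cos, math.sin
    rx = np.array([[1, 0, 0], [0, c(theta), -s(theta)], [0, s(theta), c(theta)]])
    ry = np.array([[c(phi), 0, s(phi)], [0, 1, 0], [-s(phi), 0, c(phi)]])
    rz = np.array([[c(psi), s(psi), 0], [-s(psi), c(psi), 0], [0, 0, 1]])
    return rx @ ry @ rz

R = rot_single(0.3, -0.5, 1.2)
assert np.allclose(R @ R.T, np.eye(3))
assert np.isclose(np.linalg.det(R), 1.0)
```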
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--id', type=str,
+ default='obama', help='identity of target person')
+parser.add_argument('--step', type=int,
+ default=0, help='step for running')
+
+args = parser.parse_args()
+id = args.id
+vid_file = os.path.join('dataset', 'vids', id + '.mp4')
+if not os.path.isfile(vid_file):
+ print('no video')
+ exit()
+
+id_dir = os.path.join('dataset', id)
+Path(id_dir).mkdir(parents=True, exist_ok=True)
+ori_imgs_dir = os.path.join('dataset', id, 'ori_imgs')
+Path(ori_imgs_dir).mkdir(parents=True, exist_ok=True)
+parsing_dir = os.path.join(id_dir, 'parsing')
+Path(parsing_dir).mkdir(parents=True, exist_ok=True)
+head_imgs_dir = os.path.join('dataset', id, 'head_imgs')
+Path(head_imgs_dir).mkdir(parents=True, exist_ok=True)
+com_imgs_dir = os.path.join('dataset', id, 'com_imgs')
+Path(com_imgs_dir).mkdir(parents=True, exist_ok=True)
+
+running_step = args.step
+
+# Step 0: extract wav & DeepSpeech features. This may take a few minutes, so it is
+# better to run it in a terminal so it can proceed in parallel with the commands below.
+if running_step == 0:
+ print('--- Step0: extract deepspeech feature ---')
+ wav_file = os.path.join(id_dir, 'aud.wav')
+ extract_wav_cmd = 'ffmpeg -i ' + vid_file + ' -f wav -ar 16000 ' + wav_file
+ os.system(extract_wav_cmd)
+ extract_ds_cmd = 'python data_util/deepspeech_features/extract_ds_features.py --input=' + id_dir
+ os.system(extract_ds_cmd)
+ exit()
+
+# Step 1: extract images
+if running_step == 1:
+ print('--- Step1: extract images from vids ---')
+ cap = cv2.VideoCapture(vid_file)
+ frame_num = 0
+ while True:
+ _, frame = cap.read()
+ if frame is None:
+ break
+ cv2.imwrite(os.path.join(ori_imgs_dir, str(frame_num) + '.jpg'), frame)
+ frame_num = frame_num + 1
+ cap.release()
+ exit()
+
+# Step 2: detect lands
+if running_step == 2:
+ print('--- Step 2: detect landmarks ---')
+ fa = face_alignment.FaceAlignment(
+ face_alignment.LandmarksType._2D, flip_input=False)
+ for image_path in os.listdir(ori_imgs_dir):
+ if image_path.endswith('.jpg'):
+ img = io.imread(os.path.join(ori_imgs_dir, image_path))[:, :, :3]
+ preds = fa.get_landmarks(img)
+ if preds is not None and len(preds) > 0:
+ lands = preds[0].reshape(-1, 2)[:, :2]
+ np.savetxt(os.path.join(ori_imgs_dir, image_path[:-3] + 'lms'), lands, '%f')
+
+max_frame_num = 100000
+valid_img_ids = []
+for i in range(max_frame_num):
+ if os.path.isfile(os.path.join(ori_imgs_dir, str(i) + '.lms')):
+ valid_img_ids.append(i)
+valid_img_num = len(valid_img_ids)
+tmp_img = cv2.imread(os.path.join(ori_imgs_dir, str(valid_img_ids[0]) + '.jpg'))
+h, w = tmp_img.shape[0], tmp_img.shape[1]
+
+# Step 3: face parsing
+if running_step == 3:
+ print('--- Step 3: face parsing ---')
+ face_parsing_cmd = 'python data_util/face_parsing/test.py --respath=dataset/' + \
+ id + '/parsing --imgpath=dataset/' + id + '/ori_imgs'
+ os.system(face_parsing_cmd)
+
+# Step 4: extract bc image
+if running_step == 4:
+ print('--- Step 4: extract background image ---')
+ sel_ids = np.array(valid_img_ids)[np.arange(0, valid_img_num, 20)]
+ all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose()
+ distss = []
+ for i in sel_ids:
+ parse_img = cv2.imread(os.path.join(id_dir, 'parsing', str(i) + '.png'))
+ bg = (parse_img[..., 0] == 255) & (
+ parse_img[..., 1] == 255) & (parse_img[..., 2] == 255)
+ fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0)
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
+ dists, _ = nbrs.kneighbors(all_xys)
+ distss.append(dists)
+ distss = np.stack(distss)
+ print(distss.shape)
+ max_dist = np.max(distss, 0)
+ max_id = np.argmax(distss, 0)
+ bc_pixs = max_dist > 5
+ bc_pixs_id = np.nonzero(bc_pixs)
+ bc_ids = max_id[bc_pixs]
+ imgs = []
+ num_pixs = distss.shape[1]
+ for i in sel_ids:
+ img = cv2.imread(os.path.join(ori_imgs_dir, str(i) + '.jpg'))
+ imgs.append(img)
+ imgs = np.stack(imgs).reshape(-1, num_pixs, 3)
+ bc_img = np.zeros((h * w, 3), dtype=np.uint8)
+ bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :]
+ bc_img = bc_img.reshape(h, w, 3)
+ max_dist = max_dist.reshape(h, w)
+ bc_pixs = max_dist > 5
+ bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose()
+ fg_xys = np.stack(np.nonzero(bc_pixs)).transpose()
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys)
+ distances, indices = nbrs.kneighbors(bg_xys)
+ bg_fg_xys = fg_xys[indices[:, 0]]
+ print(fg_xys.shape)
+ print(np.max(bg_fg_xys), np.min(bg_fg_xys))
+ bc_img[bg_xys[:, 0], bg_xys[:, 1],
+ :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :]
+ cv2.imwrite(os.path.join(id_dir, 'bc.jpg'), bc_img)
+
+# Step 5: save training images
+if running_step == 5:
+ print('--- Step 5: save training images ---')
+ bc_img = cv2.imread(os.path.join(id_dir, 'bc.jpg'))
+ for i in valid_img_ids:
+ parsing_img = cv2.imread(os.path.join(parsing_dir, str(i) + '.png'))
+ head_part = (parsing_img[:, :, 0] == 255) & (
+ parsing_img[:, :, 1] == 0) & (parsing_img[:, :, 2] == 0)
+ bc_part = (parsing_img[:, :, 0] == 255) & (
+ parsing_img[:, :, 1] == 255) & (parsing_img[:, :, 2] == 255)
+ img = cv2.imread(os.path.join(ori_imgs_dir, str(i) + '.jpg'))
+ img[bc_part] = bc_img[bc_part]
+ cv2.imwrite(os.path.join(com_imgs_dir, str(i) + '.jpg'), img)
+ img[~head_part] = bc_img[~head_part]
+ cv2.imwrite(os.path.join(head_imgs_dir, str(i) + '.jpg'), img)
+
+# Step 6: estimate head pose
+if running_step == 6:
+ print('--- Estimate Head Pose ---')
+ est_pose_cmd = 'python data_util/face_tracking/face_tracker.py --idname=' + \
+ id + ' --img_h=' + str(h) + ' --img_w=' + str(w) + \
+ ' --frame_num=' + str(max_frame_num)
+ os.system(est_pose_cmd)
+ exit()
+
+# Step 7: save transform param & write config file
+if running_step == 7:
+ print('--- Step 7: Save Transform Param ---')
+ params_dict = torch.load(os.path.join(id_dir, 'track_params.pt'))
+ focal_len = params_dict['focal']
+ euler_angle = params_dict['euler']
+ trans = params_dict['trans'] / 10.0
+ valid_num = euler_angle.shape[0]
+ train_val_split = int(valid_num * 10 / 11)
+ train_ids = torch.arange(0, train_val_split)
+ val_ids = torch.arange(train_val_split, valid_num)
+ rot = euler2rot(euler_angle)
+ rot_inv = rot.permute(0, 2, 1)
+ trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2))
+ pose = torch.eye(4, dtype=torch.float32)
+ save_ids = ['train', 'val']
+ train_val_ids = [train_ids, val_ids]
+ mean_z = -float(torch.mean(trans[:, 2]).item())
+ for split_idx in range(2):  # avoid shadowing: the inner loop below reuses `i`
+ transform_dict = dict()
+ transform_dict['focal_len'] = float(focal_len[0])
+ transform_dict['cx'] = float(w / 2.0)
+ transform_dict['cy'] = float(h / 2.0)
+ transform_dict['frames'] = []
+ ids = train_val_ids[split_idx]
+ save_id = save_ids[split_idx]
+ for i in ids:
+ i = i.item()
+ frame_dict = dict()
+ frame_dict['img_id'] = int(valid_img_ids[i])
+ frame_dict['aud_id'] = int(valid_img_ids[i])
+ pose[:3, :3] = rot_inv[i]
+ pose[:3, 3] = trans_inv[i, :, 0]
+ frame_dict['transform_matrix'] = pose.numpy().tolist()
+ lms = np.loadtxt(os.path.join(
+ ori_imgs_dir, str(valid_img_ids[i]) + '.lms'))
+ min_x, max_x = np.min(lms, 0)[0], np.max(lms, 0)[0]
+ cx = int((min_x + max_x) / 2.0)
+ cy = int(lms[27, 1])
+ h_w = int((max_x - cx) * 1.5)
+ h_h = int((lms[8, 1] - cy) * 1.15)
+ rect_x = cx - h_w
+ rect_y = cy - h_h
+ if rect_x < 0:
+ rect_x = 0
+ if rect_y < 0:
+ rect_y = 0
+ rect_w = min(w - 1 - rect_x, 2 * h_w)
+ rect_h = min(h - 1 - rect_y, 2 * h_h)
+ rect = np.array((rect_x, rect_y, rect_w, rect_h), dtype=np.int32)
+ frame_dict['face_rect'] = rect.tolist()
+ transform_dict['frames'].append(frame_dict)
+ with open(os.path.join(id_dir, 'transforms_' + save_id + '.json'), 'w') as fp:
+ json.dump(transform_dict, fp, indent=2, separators=(',', ': '))
+
+ dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+ testskip = int(val_ids.shape[0] / 7)
+
+ HeadNeRF_config_file = os.path.join(id_dir, 'HeadNeRF_config.txt')
+ with open(HeadNeRF_config_file, 'w') as file:
+ file.write('expname = ' + id + '_head\n')
+ file.write('datadir = ' + os.path.join(dir_path, 'dataset', id) + '\n')
+ file.write('basedir = ' + os.path.join(dir_path,
+ 'dataset', id, 'logs') + '\n')
+ file.write('near = ' + str(mean_z - 0.2) + '\n')
+ file.write('far = ' + str(mean_z + 0.4) + '\n')
+ file.write('testskip = ' + str(testskip) + '\n')
+ Path(os.path.join(dir_path, 'dataset', id, 'logs', id + '_head')
+ ).mkdir(parents=True, exist_ok=True)
+
+ ComNeRF_config_file = os.path.join(id_dir, 'TorsoNeRF_config.txt')
+ with open(ComNeRF_config_file, 'w') as file:
+ file.write('expname = ' + id + '_com\n')
+ file.write('datadir = ' + os.path.join(dir_path, 'dataset', id) + '\n')
+ file.write('basedir = ' + os.path.join(dir_path,
+ 'dataset', id, 'logs') + '\n')
+ file.write('near = ' + str(mean_z - 0.2) + '\n')
+ file.write('far = ' + str(mean_z + 0.4) + '\n')
+ file.write('testskip = ' + str(testskip) + '\n')
+ Path(os.path.join(dir_path, 'dataset', id, 'logs', id + '_com')
+ ).mkdir(parents=True, exist_ok=True)
+
+ ComNeRFTest_config_file = os.path.join(id_dir, 'TorsoNeRFTest_config.txt')
+ with open(ComNeRFTest_config_file, 'w') as file:
+ file.write('expname = ' + id + '_com\n')
+ file.write('datadir = ' + os.path.join(dir_path, 'dataset', id) + '\n')
+ file.write('basedir = ' + os.path.join(dir_path,
+ 'dataset', id, 'logs') + '\n')
+ file.write('near = ' + str(mean_z - 0.2) + '\n')
+ file.write('far = ' + str(mean_z + 0.4) + '\n')
+ file.write('with_test = ' + str(1) + '\n')
+ file.write('test_pose_file = transforms_val.json' + '\n')
+
+ print(id + ' data processing done!')
diff --git a/talkingface-toolkit-main/talkingface/data/dataset/__init__.py b/talkingface-toolkit-main/talkingface/data/dataset/__init__.py
new file mode 100644
index 00000000..4d598341
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/data/dataset/__init__.py
@@ -0,0 +1,3 @@
+from talkingface.data.dataset.dataset import Dataset
+from talkingface.data.dataset.meta_portrait_base_dataset import *
+# from talkingface.data.dataset.wav2lip_dataset import Wav2LipDataset
diff --git a/talkingface-toolkit-main/talkingface/data/dataset/dataset.py b/talkingface-toolkit-main/talkingface/data/dataset/dataset.py
new file mode 100644
index 00000000..2de27bd7
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/data/dataset/dataset.py
@@ -0,0 +1,25 @@
+import torch
+
+class Dataset(torch.utils.data.Dataset):
+ def __init__(self, config, datasplit):
+
+ """
+ args: datasplit: str, 'train', 'val' or 'test'. (This argument is required: the dataset
+ must be split into train, val and test in advance. The exact format is up to you,
+ as long as your Dataset subclass can load the data; the corresponding config
+ options are train_filelist, val_filelist and test_filelist.)
+
+ """
+
+ self.config = config
+ self.split = datasplit
+
+ def __getitem__(self, index):
+
+ """
+ Returns:
+ data: dict. Must be a dictionary; the model code is responsible for parsing its fields.
+
+ """
+
+
+ raise NotImplementedError
\ No newline at end of file
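A minimal, hypothetical subclass sketch of the contract described in `dataset.py` above: `datasplit` selects the filelist (`train_filelist` / `val_filelist` / `test_filelist` in the config) and `__getitem__` returns a dict. `ToyDataset` and its keys are illustrative only, and the torch dependency is elided:

```python
# Hypothetical subclass sketch of the Dataset contract: datasplit selects the
# filelist, and __getitem__ must return a dict (parsed later by the model).
class ToyDataset:
    def __init__(self, config, datasplit):
        self.config = config
        self.split = datasplit
        self.samples = config[datasplit + '_filelist']

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # A real subclass would load frames/audio here; the key point is
        # that the return value is a dict.
        return {"input": self.samples[index], "gt": None}

ds = ToyDataset({"train_filelist": ["clip0", "clip1"]}, "train")
assert len(ds) == 2
assert ds[0] == {"input": "clip0", "gt": None}
```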
diff --git a/talkingface-toolkit-main/talkingface/data/dataset/wav2lip_dataset.py b/talkingface-toolkit-main/talkingface/data/dataset/wav2lip_dataset.py
new file mode 100644
index 00000000..e52ae28d
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/data/dataset/wav2lip_dataset.py
@@ -0,0 +1,158 @@
+from os.path import dirname, join, basename, isfile
+from tqdm import tqdm
+from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio
+import python_speech_features
+
+import torch
+from torch import nn
+from torch import optim
+import torch.backends.cudnn as cudnn
+from torch.utils import data as data_utils
+import numpy as np
+import librosa
+from talkingface.data.dataset.dataset import Dataset
+
+from glob import glob
+
+import os, random, cv2, argparse
+
+class Wav2LipDataset(Dataset):
+ def __init__(self, config, datasplit):
+ super().__init__(config, datasplit)
+ self.all_videos = self.get_image_list(self.config['preprocessed_root'], datasplit)
+ self.audio_processor = Wav2LipAudio(self.config)
+
+
+ def get_image_list(self, data_root, split_path):
+ filelist = []
+
+ with open(split_path) as f:
+ for line in f:
+ line = line.strip()
+ if ' ' in line: line = line.split()[0]
+ filelist.append(os.path.join(data_root, line))
+
+ return filelist
+
+
+ def get_frame_id(self, frame):
+ return int(basename(frame).split('.')[0])
+
+ def get_window(self, start_frame):
+ start_id = self.get_frame_id(start_frame)
+ vidname = dirname(start_frame)
+
+ window_fnames = []
+ for frame_id in range(start_id, start_id + self.config['syncnet_T']):
+ frame = join(vidname, '{}.jpg'.format(frame_id))
+ if not isfile(frame):
+ return None
+ window_fnames.append(frame)
+ return window_fnames
+
+ def read_window(self, window_fnames):
+ if window_fnames is None: return None
+ window = []
+ for fname in window_fnames:
+ img = cv2.imread(fname)
+ if img is None:
+ return None
+ try:
+ img = cv2.resize(img, (self.config['img_size'], self.config['img_size']))
+ except Exception as e:
+ return None
+
+ window.append(img)
+
+ return window
+
+ def crop_audio_window(self, spec, start_frame):
+ if isinstance(start_frame, int):
+ start_frame_num = start_frame
+ else:
+ start_frame_num = self.get_frame_id(start_frame)
+ start_idx = int(80. * (start_frame_num / float(self.config['fps'])))
+
+ end_idx = start_idx + self.config['syncnet_mel_step_size']
+
+ return spec[start_idx : end_idx, :]
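In `crop_audio_window` above, the hard-coded `80.` is the mel frame rate in frames per second (with the usual Wav2Lip defaults, sr=16000 and hop 200 give 16000/200 = 80). A standalone check of the index mapping, with a hypothetical fps and mel step size:

```python
# Standalone check of the video-frame -> mel-frame index mapping used by
# crop_audio_window. fps and syncnet_mel_step_size below are hypothetical.
fps = 25.0
syncnet_mel_step_size = 16

def crop_bounds(start_frame_num):
    start_idx = int(80. * (start_frame_num / fps))
    return start_idx, start_idx + syncnet_mel_step_size

assert crop_bounds(0) == (0, 16)
assert crop_bounds(10) == (32, 48)
```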
+
+ def get_segmented_mels(self, spec, start_frame):
+ mels = []
+ assert self.config['syncnet_T'] == 5
+ start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing
+ if start_frame_num - 2 < 0: return None
+ for i in range(start_frame_num, start_frame_num + self.config['syncnet_T']):
+ m = self.crop_audio_window(spec, i - 2)
+ if m.shape[0] != self.config['syncnet_mel_step_size']:
+ return None
+ mels.append(m.T)
+
+ mels = np.asarray(mels)
+
+ return mels
+
+ def prepare_window(self, window):
+ # 3 x T x H x W
+ x = np.asarray(window) / 255.
+ x = np.transpose(x, (3, 0, 1, 2))
+
+ return x
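`prepare_window` above converts a list of `T` HxWx3 uint8 frames into a 3xTxHxW float array in [0, 1]; a quick shape/range check with dummy frames (sizes are illustrative):

```python
import numpy as np

# Shape/range check of the prepare_window transform: list of T HxWx3 uint8
# frames -> 3 x T x H x W float array scaled to [0, 1].
window = [np.full((96, 96, 3), 255, dtype=np.uint8)] * 5
x = np.transpose(np.asarray(window) / 255., (3, 0, 1, 2))
assert x.shape == (3, 5, 96, 96)
assert x.max() == 1.0
```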
+
+ def __len__(self):
+ return len(self.all_videos)
+
+ def __getitem__(self, idx):
+ while True:
+ idx = random.randint(0, len(self.all_videos) - 1)
+ vidname = self.all_videos[idx]
+ img_names = list(glob(join(vidname, '*.jpg')))
+ if len(img_names) <= 3 * self.config['syncnet_T']:
+ continue
+
+ img_name = random.choice(img_names)
+ wrong_img_name = random.choice(img_names)
+ while wrong_img_name == img_name:
+ wrong_img_name = random.choice(img_names)
+
+ window_fnames = self.get_window(img_name)
+ wrong_window_fnames = self.get_window(wrong_img_name)
+ if window_fnames is None or wrong_window_fnames is None:
+ continue
+
+ window = self.read_window(window_fnames)
+ if window is None:
+ continue
+
+ wrong_window = self.read_window(wrong_window_fnames)
+ if wrong_window is None:
+ continue
+
+ try:
+ wavpath = join(vidname, "audio.wav")
+ wav = self.audio_processor.load_wav(wavpath, self.config['sample_rate'])
+ orig_mel = self.audio_processor.melspectrogram(wav).T
+ except Exception as e:
+ continue
+
+ mel = self.crop_audio_window(orig_mel.copy(), img_name)
+
+ if (mel.shape[0] != self.config['syncnet_mel_step_size']):
+ continue
+
+ indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name)
+ if indiv_mels is None: continue
+
+ window = self.prepare_window(window)
+ y = window.copy()
+ window[:, :, window.shape[2]//2:] = 0.
+
+ wrong_window = self.prepare_window(wrong_window)
+ x = np.concatenate([window, wrong_window], axis=0)
+
+ x = torch.FloatTensor(x)
+ mel = torch.FloatTensor(mel.T).unsqueeze(0)
+ # breakpoint()
+ indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1)
+ y = torch.FloatTensor(y)
+ return {"input_frames":x, "indiv_mels":indiv_mels, "mels":mel, "gt":y}
diff --git a/talkingface-toolkit-main/talkingface/evaluator/__init__.py b/talkingface-toolkit-main/talkingface/evaluator/__init__.py
new file mode 100644
index 00000000..81d4f2ee
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/evaluator/__init__.py
@@ -0,0 +1,5 @@
+from talkingface.evaluator.metric_models import *
+from talkingface.evaluator.metrics import *
+from talkingface.evaluator.register import *
+from talkingface.evaluator.evaluator import *
+from talkingface.evaluator.meta_portrait_base_inference import *
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/evaluator/base_metric.py b/talkingface-toolkit-main/talkingface/evaluator/base_metric.py
new file mode 100644
index 00000000..81e15721
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/evaluator/base_metric.py
@@ -0,0 +1,98 @@
+import torch
+from talkingface.utils import EvaluatorType
+
+class AbstractMetric(object):
+ """:class:`AbstractMetric` is the base object of all metrics. If you want to
+ implement a metric, you should inherit this class.
+
+ Args:
+ config (Config): the config of evaluator.
+ """
+
+ smaller = False
+
+ def __init__(self, config):
+ self.decimal_place = config["metric_decimal_place"]
+
+ def calculate_metric(self, dataobject):
+ """Get the dictionary of a metric.
+
+ Args:
+ dataobject: (dict): it contains all the information needed to calculate metrics.
+
+ Returns:
+ dict: such as ``{'LSE-C': 0.0000}``
+ """
+ raise NotImplementedError("Method [calculate_metric] should be implemented.")
+
+
+class SyncMetric(AbstractMetric):
+ """Base class for all Sync metrics. If you want to implement a sync metric, you can inherit this class.
+ """
+
+ metric_type = EvaluatorType.SYNC
+ metric_need = ["generated_video"]
+ def __init__(self, config):
+ super(SyncMetric, self).__init__(config)
+
+ def get_videolist(self, dataobject):
+ """Get the list of videos.
+
+ Args:
+ dataobject(DataStruct): (dict): it contains all the information needed to calculate metrics.
+
+ Returns:
+ list: a list of videos.
+ """
+ return dataobject["generated_video"]
+
+ def metric_info(self, dataobject):
+ """Calculate the value of the metric.
+
+ Args:
+ dataobject(DataStruct): it contains all the information needed to calculate metrics.
+
+ Returns:
+ dict: {"LSE-C": LSE_C, "LSE-D": LSE_D}
+ """
+ raise NotImplementedError("Method [metric_info] should be implemented.")
+
+class VideoQMetric(AbstractMetric):
+ """Base class for all Video Quality metrics. If you want to implement a Video Quality metric, you can inherit this class.
+ """
+
+ metric_type = EvaluatorType.VIDEOQ
+ def __init__(self, config):
+ super(VideoQMetric, self).__init__(config)
+
+ def get_videopair(self, dataobject):
+ return list(zip(dataobject["generated_video"], dataobject["real_video"]))
+
+ def metric_info(self, dataobject):
+ """Calculate the value of the metric.
+
+ Args:
+ dataobject(DataStruct): (dict): it contains all the information needed to calculate metrics.
+
+ Returns:
+ float: the value of the metric.
+ """
+ raise NotImplementedError("Method [metric_info] should be implemented.")
+
+# class AudioQMetric(AbstractMetric):
+# """Base class for all Audio Quality metrics. If you want to implement a Audio Quality metric, you can inherit this class.
+# """
+# def __init__(self, config):
+# super(SyncMetric, self).__init__(config)
+
+# def metric_info(self, dataobject):
+# """Calculate the value of the metric.
+
+# Args:
+# dataobject(DataStruct): it contains all the information needed to calculate metrics.
+
+# Returns:
+# float: the value of the metric.
+# """
+# raise NotImplementedError("Method [metric_info] should be implemented.")
+
diff --git a/talkingface-toolkit-main/talkingface/evaluator/evaluator.py b/talkingface-toolkit-main/talkingface/evaluator/evaluator.py
new file mode 100644
index 00000000..d2592bee
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/evaluator/evaluator.py
@@ -0,0 +1,22 @@
+from talkingface.evaluator.register import metrics_dict
+
+class Evaluator(object):
+ def __init__(self, config):
+ self.config = config
+ self.metrics = [metric.lower() for metric in self.config["metrics"]]
+ self.metric_class = {}
+
+ for metric in self.metrics:
+ if metric not in metrics_dict:
+ raise ValueError(f"Metric '{metric}' is not defined.")
+ self.metric_class[metric] = metrics_dict[metric](self.config)
+
+ def evaluate(self, datadict):
+
+ result_dict = {}
+
+ for metric in self.metrics:
+ metric_val = self.metric_class[metric].calculate_metric(datadict)
+ result_dict[metric] = metric_val
+
+ return result_dict
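The `Evaluator` above is a thin registry dispatcher: metric names are lower-cased, looked up in `metrics_dict`, and instantiated with the config. A self-contained sketch of the same pattern, where `DummyMetric` and the local registry are illustrative stand-ins:

```python
# Minimal sketch of the registry-dispatch pattern Evaluator implements.
# DummyMetric and the local metrics_dict are illustrative stand-ins.
class DummyMetric:
    def __init__(self, config):
        self.config = config

    def calculate_metric(self, datadict):
        return len(datadict["generated_video"])

metrics_dict = {"dummy": DummyMetric}

def evaluate(config, datadict):
    results = {}
    for name in (m.lower() for m in config["metrics"]):
        if name not in metrics_dict:
            raise ValueError(f"Metric '{name}' is not defined.")
        results[name] = metrics_dict[name](config).calculate_metric(datadict)
    return results

out = evaluate({"metrics": ["Dummy"]}, {"generated_video": ["a.mp4", "b.mp4"]})
assert out == {"dummy": 2}
```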
diff --git a/talkingface-toolkit-main/talkingface/evaluator/metric_models.py b/talkingface-toolkit-main/talkingface/evaluator/metric_models.py
new file mode 100644
index 00000000..3a38ccba
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/evaluator/metric_models.py
@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+
+class S(nn.Module):
+ def __init__(self, num_layers_in_fc_layers = 1024):
+ super(S, self).__init__()
+
+ self.__nFeatures__ = 24
+ self.__nChs__ = 32
+ self.__midChs__ = 32
+
+ self.netcnnaud = nn.Sequential(
+ nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1)),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=(1,1), stride=(1,1)),
+
+ nn.Conv2d(64, 192, kernel_size=(3,3), stride=(1,1), padding=(1,1)),
+ nn.BatchNorm2d(192),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=(3,3), stride=(1,2)),
+
+ nn.Conv2d(192, 384, kernel_size=(3,3), padding=(1,1)),
+ nn.BatchNorm2d(384),
+ nn.ReLU(inplace=True),
+
+ nn.Conv2d(384, 256, kernel_size=(3,3), padding=(1,1)),
+ nn.BatchNorm2d(256),
+ nn.ReLU(inplace=True),
+
+ nn.Conv2d(256, 256, kernel_size=(3,3), padding=(1,1)),
+ nn.BatchNorm2d(256),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)),
+
+ nn.Conv2d(256, 512, kernel_size=(5,4), padding=(0,0)),
+ nn.BatchNorm2d(512),
+ nn.ReLU(),
+ )
+
+ self.netfcaud = nn.Sequential(
+ nn.Linear(512, 512),
+ nn.BatchNorm1d(512),
+ nn.ReLU(),
+ nn.Linear(512, num_layers_in_fc_layers),
+ )
+
+ self.netfclip = nn.Sequential(
+ nn.Linear(512, 512),
+ nn.BatchNorm1d(512),
+ nn.ReLU(),
+ nn.Linear(512, num_layers_in_fc_layers),
+ )
+
+ self.netcnnlip = nn.Sequential(
+ nn.Conv3d(3, 96, kernel_size=(5,7,7), stride=(1,2,2), padding=0),
+ nn.BatchNorm3d(96),
+ nn.ReLU(inplace=True),
+ nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)),
+
+ nn.Conv3d(96, 256, kernel_size=(1,5,5), stride=(1,2,2), padding=(0,1,1)),
+ nn.BatchNorm3d(256),
+ nn.ReLU(inplace=True),
+ nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)),
+
+ nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)),
+ nn.BatchNorm3d(256),
+ nn.ReLU(inplace=True),
+
+ nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)),
+ nn.BatchNorm3d(256),
+ nn.ReLU(inplace=True),
+
+ nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)),
+ nn.BatchNorm3d(256),
+ nn.ReLU(inplace=True),
+ nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)),
+
+ nn.Conv3d(256, 512, kernel_size=(1,6,6), padding=0),
+ nn.BatchNorm3d(512),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward_aud(self, x):
+ mid = self.netcnnaud(x) # N x ch x 24 x M
+ mid = mid.view((mid.size()[0], -1)) # N x (ch x 24)
+ out = self.netfcaud(mid)
+
+ return out
+
+ def forward_lip(self, x):
+
+ mid = self.netcnnlip(x)
+ mid = mid.view((mid.size()[0], -1)) # N x (ch x 24)
+ out = self.netfclip(mid)
+
+ return out
+
+ def forward_lipfeat(self, x):
+
+ mid = self.netcnnlip(x)
+ out = mid.view((mid.size()[0], -1)) # N x (ch x 24)
+
+ return out
+
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/evaluator/metrics.py b/talkingface-toolkit-main/talkingface/evaluator/metrics.py
new file mode 100644
index 00000000..eb2692c7
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/evaluator/metrics.py
@@ -0,0 +1,243 @@
+import torch
+import numpy
+import time, pdb, argparse, subprocess, os, math, glob
+import cv2
+import python_speech_features
+from tqdm import tqdm
+from scipy import signal
+from scipy.io import wavfile
+from talkingface.evaluator.metric_models import *
+from shutil import rmtree
+from skimage.metrics import structural_similarity as ssim
+from talkingface.evaluator.base_metric import AbstractMetric, SyncMetric, VideoQMetric
+from talkingface.utils.logger import set_color
+
+class LSE(SyncMetric):
+
+ '''
+ '''
+ def __init__(self, config, num_layers_in_fc_layers = 1024):
+ super(LSE, self).__init__(config)
+ self.config = config
+ self.syncnet = S(num_layers_in_fc_layers = num_layers_in_fc_layers)
+
+ def metric_info(self, videofile):
+ self.loadParameters(self.config['lse_checkpoint_path'])
+ self.syncnet.to(self.config["device"])
+ self.syncnet.eval()
+
+ if os.path.exists(os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'])):
+ rmtree(os.path.join(self.config['temp_dir'], self.config['lse_reference_dir']))
+
+ os.makedirs(os.path.join(self.config['temp_dir'], self.config['lse_reference_dir']))
+
+ command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'],'%06d.jpg')))
+ output = subprocess.call(command, shell=True, stdout=None)
+
+ command = ("ffmpeg -loglevel error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'],'audio.wav')))
+ output = subprocess.call(command, shell=True, stdout=None)
+
+ images = []
+
+ flist = glob.glob(os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'],'*.jpg'))
+ flist.sort()
+
+ for fname in flist:
+ img_input = cv2.imread(fname)
+ img_input = cv2.resize(img_input, (224, 224))
+ images.append(img_input)
+
+ im = numpy.stack(images, axis=3)
+ im = numpy.expand_dims(im, axis=0)
+ im = numpy.transpose(im, (0, 3, 4, 1, 2))
+
+ imtv = torch.from_numpy(im.astype(float)).float() # Variable is deprecated; plain tensors suffice
+
+ sample_rate, audio = wavfile.read(os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'],'audio.wav'))
+ mfcc = zip(*python_speech_features.mfcc(audio,sample_rate))
+ mfcc = numpy.stack([numpy.array(i) for i in mfcc])
+
+ cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0)
+        cct = torch.from_numpy(cc.astype(float)).float()
+
+ # ========== ==========
+ # Check audio and video input length
+ # ========== ==========
+
+ #if (float(len(audio))/16000) != (float(len(images))/25) :
+ # print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25))
+
+ min_length = min(len(images),math.floor(len(audio)/640))
+
+ # ========== ==========
+ # Generate video and audio feats
+ # ========== ==========
+
+ lastframe = min_length-5
+ im_feat = []
+ cc_feat = []
+
+ tS = time.time()
+ for i in range(0,lastframe,self.config['evaluate_batch_size']):
+
+ im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe, i+self.config['evaluate_batch_size'])) ]
+ im_in = torch.cat(im_batch,0)
+            im_out = self.syncnet.forward_lip(im_in.to(self.config["device"]))
+ im_feat.append(im_out.data.cpu())
+
+ cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe, i+self.config['evaluate_batch_size'])) ]
+ cc_in = torch.cat(cc_batch,0)
+            cc_out = self.syncnet.forward_aud(cc_in.to(self.config["device"]))
+ cc_feat.append(cc_out.data.cpu())
+
+ im_feat = torch.cat(im_feat,0)
+ cc_feat = torch.cat(cc_feat,0)
+
+ # ========== ==========
+ # Compute offset
+ # ========== ==========
+
+ #print('Compute time %.3f sec.' % (time.time()-tS))
+
+ dists = self.calc_pdist(im_feat,cc_feat,vshift=self.config['vshift'])
+ mdist = torch.mean(torch.stack(dists,1),1)
+
+ minval, minidx = torch.min(mdist,0)
+
+ offset = self.config['vshift']-minidx
+ conf = torch.median(mdist) - minval
+
+ fdist = numpy.stack([dist[minidx].numpy() for dist in dists])
+ # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15)
+ fconf = torch.median(mdist).numpy() - fdist
+ fconfm = signal.medfilt(fconf,kernel_size=9)
+
+ numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format})
+
+ return conf.numpy(), minval.numpy()
+
+
+
+ def calc_pdist(self, feat1, feat2, vshift=10):
+
+ win_size = vshift*2+1
+
+ feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift))
+
+ dists = []
+
+ for i in range(0,len(feat1)):
+
+ dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:]))
+
+ return dists
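The sliding-window comparison in `calc_pdist` is the heart of the offset search: for every video-frame feature it computes distances to the audio features at every temporal shift in [-vshift, +vshift]. A toy, self-contained sketch of the same logic (independent of SyncNet):

```python
import torch
import torch.nn.functional as F

def calc_pdist_sketch(feat1, feat2, vshift=3):
    # Pad the audio features so every video frame can be compared
    # against neighbours at every shift in [-vshift, +vshift].
    win_size = vshift * 2 + 1
    feat2p = F.pad(feat2, (0, 0, vshift, vshift))
    dists = []
    for i in range(len(feat1)):
        # L2 distance between frame i and each of the win_size shifted audio features.
        dists.append(F.pairwise_distance(
            feat1[[i], :].repeat(win_size, 1), feat2p[i:i + win_size, :]))
    return dists

feat_v = torch.randn(8, 4)
dists = calc_pdist_sketch(feat_v, feat_v.clone(), vshift=3)
```

With identical streams, the distance at zero shift (window index `vshift`) collapses to ~0, which is why the minimum over the averaged distance curve yields the AV offset.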
+
+ def calculate_metric(self, dataobject):
+ video_list = self.get_videolist(dataobject)
+
+ LSE_dict = {}
+
+ iter_data = (
+ tqdm(
+ video_list,
+ total=len(video_list),
+ desc=set_color("calculate for lip-audio sync", "yellow")
+ )
+ if self.config['show_progress']
+ else video_list
+ )
+
+ for video in iter_data:
+ LSE_C, LSE_D = self.metric_info(video)
+ if "LSE_C" not in LSE_dict:
+ LSE_dict["LSE_C"] = [LSE_C]
+ else:
+ LSE_dict["LSE_C"].append(LSE_C)
+ if "LSE_D" not in LSE_dict:
+ LSE_dict["LSE_D"] = [LSE_D]
+ else:
+ LSE_dict["LSE_D"].append(LSE_D)
+
+        return {"LSE-C": sum(LSE_dict["LSE_C"]) / len(LSE_dict["LSE_C"]), "LSE-D": sum(LSE_dict["LSE_D"]) / len(LSE_dict["LSE_D"])}
+
+ def loadParameters(self, path):
+ loaded_state = torch.load(path, map_location=lambda storage, loc: storage)
+
+ self_state = self.syncnet.state_dict()
+
+ for name, param in loaded_state.items():
+
+ self_state[name].copy_(param)
+
+
+class SSIM(VideoQMetric):
+
+ metric_need = ["generated_video", "real_video"]
+
+ def __init__(self, config):
+ super(SSIM, self).__init__(config)
+ self.config = config
+
+    def _read_gray_frames(self, videofile):
+        """Decode a video into 224x224 grayscale frames for SSIM comparison."""
+        frames = []
+        video = cv2.VideoCapture(videofile)
+        while video.isOpened():
+            ret, frame = video.read()
+            if not ret:
+                break
+            frame = cv2.resize(frame, (224, 224))
+            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY))
+        video.release()
+        return frames
+
+    def metric_info(self, g_videofile, r_videofile):
+        g_frames = self._read_gray_frames(g_videofile)
+        r_frames = self._read_gray_frames(r_videofile)
+
+ min_frames = min(len(g_frames), len(r_frames))
+
+ g_frames = g_frames[:min_frames]
+ r_frames = r_frames[:min_frames]
+
+ ssim_scores = []
+ for frame1, frame2 in zip(g_frames, r_frames):
+ # 计算两帧之间的SSIM
+ score, _ = ssim(frame1, frame2, full=True)
+ ssim_scores.append(score)
+
+ return numpy.mean(ssim_scores)
+
+ def calculate_metric(self, dataobject):
+
+ pair_list = self.get_videopair(dataobject)
+
+ ssim_score_total = []
+
+ iter_data = (
+ tqdm(
+ pair_list,
+ total=len(pair_list),
+ desc=set_color("calculate for video quality ssim", "yellow")
+ )
+ if self.config['show_progress']
+ else pair_list
+ )
+ for pair in iter_data:
+ g_video = pair[0]
+ r_video = pair[1]
+ ssim_score = self.metric_info(g_video, r_video)
+ ssim_score_total.append(ssim_score)
+
+ return sum(ssim_score_total)/len(ssim_score_total)
diff --git a/talkingface-toolkit-main/talkingface/evaluator/register.py b/talkingface-toolkit-main/talkingface/evaluator/register.py
new file mode 100644
index 00000000..2ff33c3b
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/evaluator/register.py
@@ -0,0 +1,46 @@
+import inspect
+import sys
+
+def cluster_info(module_name):
+    """Collect information of all metrics, including:
+
+    - ``metric_need``: Information needed to calculate this metric, e.g. the
+      list of generated videos, or pairs of generated and real videos.
+    - ``metric_type``: The type of the metric, range in ``EvaluatorType``.
+    - ``smaller``: Whether a smaller metric value represents better performance,
+      range in ``True`` and ``False``, default to ``False``.
+
+    Args:
+        module_name (str): the name of module ``talkingface.evaluator.metrics``.
+
+    Returns:
+        tuple: a list of metric names for which smaller is better, plus three
+        dictionaries mapping metric names to their ``metric_need``,
+        ``metric_type``, and metric class, respectively.
+    """
+ smaller_m = []
+ m_dict, m_info, m_types = {}, {}, {}
+ metric_class = inspect.getmembers(
+ sys.modules[module_name],
+ lambda x: inspect.isclass(x) and x.__module__ == module_name,
+ )
+ for name, metric_cls in metric_class:
+ name = name.lower()
+ m_dict[name] = metric_cls
+ if hasattr(metric_cls, "metric_need"):
+ m_info[name] = metric_cls.metric_need
+ else:
+ raise AttributeError(f"Metric '{name}' has no attribute [metric_need].")
+ if hasattr(metric_cls, "metric_type"):
+ m_types[name] = metric_cls.metric_type
+ else:
+ raise AttributeError(f"Metric '{name}' has no attribute [metric_type].")
+        if getattr(metric_cls, "smaller", False):
+ smaller_m.append(name)
+ return smaller_m, m_info, m_types, m_dict
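The `inspect.getmembers` discovery above can be exercised on a throwaway module; the class and module names below are illustrative, not part of the toolkit:

```python
import inspect
import sys
import types

# Build a throwaway module holding two toy metric classes.
mod = types.ModuleType("toy_metrics")

class ToyHigherBetter:
    metric_need = ["generated_video"]
    metric_type = "value"
    smaller = False

class ToyLowerBetter:
    metric_need = ["generated_video", "real_video"]
    metric_type = "value"
    smaller = True

# Reassign __module__ so the predicate below accepts them as module members.
ToyHigherBetter.__module__ = ToyLowerBetter.__module__ = "toy_metrics"
mod.ToyHigherBetter, mod.ToyLowerBetter = ToyHigherBetter, ToyLowerBetter
sys.modules["toy_metrics"] = mod

# Same discovery mechanism as cluster_info: classes defined in that module.
members = inspect.getmembers(
    sys.modules["toy_metrics"],
    lambda x: inspect.isclass(x) and x.__module__ == "toy_metrics",
)
smaller_m = [name.lower() for name, cls in members if cls.smaller]
```

Registering a new metric therefore only requires defining a class with `metric_need`, `metric_type`, and optionally `smaller` in the metrics module; no manual registry edits are needed.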
+
+
+metric_module_name = "talkingface.evaluator.metrics"
+smaller_metrics, metric_information, metric_types, metrics_dict = cluster_info(
+ metric_module_name
+)
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/model/__init__.py b/talkingface-toolkit-main/talkingface/model/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/talkingface-toolkit-main/talkingface/model/abstract_speech.py b/talkingface-toolkit-main/talkingface/model/abstract_speech.py
new file mode 100644
index 00000000..49d6a858
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/abstract_speech.py
@@ -0,0 +1,75 @@
+from logging import getLogger
+
+import torch
+import torch.nn as nn
+import numpy as np
+from talkingface.utils import set_color
+
+class AbstractSpeech(nn.Module):
+ """Abstract class for talking face model."""
+
+ def __init__(self):
+ self.logger = getLogger()
+ super(AbstractSpeech, self).__init__()
+
+    def calculate_loss(self, interaction):
+        r"""Calculate the training loss for a batch of data.
+
+        Args:
+            interaction (Interaction): Interaction class of the batch.
+
+        Returns:
+            dict: {"loss": loss, "xxx": xxx}
+            The returned dict must contain the key ``loss``, which is the
+            weighted total loss (the total may consist of several parts);
+            ``xxx`` stands for any individual loss component.
+        """
+        raise NotImplementedError
+
+    def predict(self, interaction):
+        r"""Generate the model output for a batch of data.
+
+        Args:
+            interaction (Interaction): Interaction class of the batch.
+
+        Returns:
+            video/image numpy/tensor
+        """
+        raise NotImplementedError
+
+    def generate_batch(self):
+        """Generate data in batches according to the split test_filelist.
+
+        Returns:
+            dict: {"generated_audio": [generated_audio], "real_audio": [real_audio]}
+            The result must be a dict with exactly the keys ``generated_audio``
+            and ``real_audio``, whose values are lists of equal length, so that
+            every generated audio clip has a corresponding (or approximately
+            corresponding) real audio clip.
+        """
+        raise NotImplementedError
+
+ def other_parameter(self):
+ if hasattr(self, "other_parameter_name"):
+ return {key: getattr(self, key) for key in self.other_parameter_name}
+ return dict()
+
+ def load_other_parameter(self, para):
+ if para is None:
+ return
+ for key, value in para.items():
+ setattr(self, key, value)
+
+ def __str__(self):
+ """
+ Model prints with number of trainable parameters
+ """
+ model_parameters = filter(lambda p: p.requires_grad, self.parameters())
+ params = sum([np.prod(p.size()) for p in model_parameters])
+ return (
+ super().__str__()
+ + set_color("\nTrainable parameters", "blue")
+ + f": {params}"
+ )
+
+
diff --git a/talkingface-toolkit-main/talkingface/model/abstract_talkingface.py b/talkingface-toolkit-main/talkingface/model/abstract_talkingface.py
new file mode 100644
index 00000000..14810e0d
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/abstract_talkingface.py
@@ -0,0 +1,75 @@
+from logging import getLogger
+
+import torch
+import torch.nn as nn
+import numpy as np
+from talkingface.utils import set_color
+
+class AbstractTalkingFace(nn.Module):
+ """Abstract class for talking face model."""
+
+ def __init__(self):
+ self.logger = getLogger()
+ super(AbstractTalkingFace, self).__init__()
+
+    def calculate_loss(self, interaction):
+        r"""Calculate the training loss for a batch of data.
+
+        Args:
+            interaction (Interaction): Interaction class of the batch.
+
+        Returns:
+            dict: {"loss": loss, "xxx": xxx}
+            The returned dict must contain the key ``loss``, which is the
+            weighted total loss (the total may consist of several parts);
+            ``xxx`` stands for any individual loss component.
+        """
+        raise NotImplementedError
+
+    def predict(self, interaction):
+        r"""Generate the model output for a batch of data.
+
+        Args:
+            interaction (Interaction): Interaction class of the batch.
+
+        Returns:
+            video/image numpy/tensor
+        """
+        raise NotImplementedError
+
+    def generate_batch(self):
+        """Generate data in batches according to the split test_filelist.
+
+        Returns:
+            dict: {"generated_video": [generated_video], "real_video": [real_video]}
+            The result must be a dict with exactly the keys ``generated_video``
+            and ``real_video``, whose values are lists of equal length, so that
+            every generated video has a corresponding (or approximately
+            corresponding) real video.
+        """
+        raise NotImplementedError
+
+ def other_parameter(self):
+ if hasattr(self, "other_parameter_name"):
+ return {key: getattr(self, key) for key in self.other_parameter_name}
+ return dict()
+
+ def load_other_parameter(self, para):
+ if para is None:
+ return
+ for key, value in para.items():
+ setattr(self, key, value)
+
+ def __str__(self):
+ """
+ Model prints with number of trainable parameters
+ """
+ model_parameters = filter(lambda p: p.requires_grad, self.parameters())
+ params = sum([np.prod(p.size()) for p in model_parameters])
+ return (
+ super().__str__()
+ + set_color("\nTrainable parameters", "blue")
+ + f": {params}"
+ )
+
+
diff --git a/talkingface-toolkit-main/talkingface/model/audio_driven_talkingface/__init__.py b/talkingface-toolkit-main/talkingface/model/audio_driven_talkingface/__init__.py
new file mode 100644
index 00000000..04a35f33
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/audio_driven_talkingface/__init__.py
@@ -0,0 +1 @@
+from talkingface.model.audio_driven_talkingface.wav2lip import Wav2Lip, SyncNet_color
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/model/image_driven_talkingface/__init__.py b/talkingface-toolkit-main/talkingface/model/image_driven_talkingface/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/talkingface-toolkit-main/talkingface/model/layers.py b/talkingface-toolkit-main/talkingface/model/layers.py
new file mode 100644
index 00000000..ed83da00
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/layers.py
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+ self.residual = residual
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ if self.residual:
+ out += x
+ return self.act(out)
+
+class nonorm_Conv2d(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ )
+ self.act = nn.LeakyReLU(0.01, inplace=True)
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ return self.act(out)
+
+class Conv2dTranspose(nn.Module):
+ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ return self.act(out)
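The `residual` flag in the `Conv2d` wrapper above only works when `cin == cout` and the kernel/stride/padding preserve spatial size, so that `out += x` is shape-compatible. A quick standalone check (the wrapper is re-declared here so the snippet runs on its own):

```python
import torch
from torch import nn

class Conv2dSketch(nn.Module):
    # Mirror of the toolkit's Conv2d wrapper: conv + batchnorm, optional skip.
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            # Requires cin == cout and a shape-preserving conv (e.g. k=3, s=1, p=1).
            out = out + x
        return self.act(out)

block = Conv2dSketch(16, 16, kernel_size=3, stride=1, padding=1, residual=True)
y = block(torch.randn(2, 16, 32, 32))
```

A residual block with `cin != cout` (or stride > 1) would raise a shape error at the addition, which is why Wav2Lip only sets `residual=True` on same-channel, stride-1 layers.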
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/__init__.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/__init__.py
new file mode 100644
index 00000000..761f9524
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/__init__.py
@@ -0,0 +1,142 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+    -- <set_input>: unpack data from dataset and apply preprocessing.
+    -- <forward>: produce intermediate results.
+    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example of usage.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+import os
+import importlib
+import numpy as np
+import torch
+import torch.nn as nn
+from talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.base_model import BaseModel
+
+
+
+def find_model_using_name(model_name):
+    """Import the module "models/[model_name]_model.py".
+
+    In the file, the class called [ModelName]Model will
+    be instantiated. It has to be a subclass of BaseModel,
+    and the match is case-insensitive.
+    """
+ model_filename = "talkingface.model.audio_driven_talkingface.LiveSpeechPortraits." + model_name + "_model"
+ modellib = importlib.import_module(model_filename)
+ model = None
+ target_model_name = model_name.replace('_', '') + 'model'
+ for name, cls in modellib.__dict__.items():
+ if name.lower() == target_model_name.lower() \
+ and issubclass(cls, BaseModel):
+ model = cls
+
+ if model is None:
+ print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+ exit()
+
+ return model
+
+
+def get_option_setter(model_name):
+ """Return the static method of the model class."""
+ model_class = find_model_using_name(model_name)
+ return model_class.modify_commandline_options
+
+
+def create_model(opt):
+ """Create a model given the option.
+
+    This function wraps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py'
+
+ Example:
+# >>> from models import create_model
+# >>> model = create_model(opt)
+ """
+ model = find_model_using_name(opt.model_name)
+ instance = model(opt)
+ print("model [%s] was created" % type(instance).__name__)
+ return instance
+
+
+def save_models(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, modelG, modelD, end_of_epoch=False):
+ if not end_of_epoch:
+ if total_steps % opt.save_latest_freq == 0:
+ visualizer.vis_print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
+ modelG.module.save('latest')
+ modelD.module.save('latest')
+ np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
+ else:
+ if epoch % opt.save_epoch_freq == 0:
+ visualizer.vis_print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
+ modelG.module.save('latest')
+ modelD.module.save('latest')
+ modelG.module.save(epoch)
+ modelD.module.save(epoch)
+ np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
+
+
+def update_models(opt, epoch, modelG, modelD, dataset_warp):
+ ### linearly decay learning rate after certain iterations
+ if epoch > opt.niter:
+ modelG.module.update_learning_rate(epoch, 'G')
+ modelD.module.update_learning_rate(epoch, 'D')
+
+ ### gradually grow training sequence length
+ if (epoch % opt.niter_step) == 0:
+ dataset_warp.dataset.update_training_batch(epoch//opt.niter_step)
+# modelG.module.update_training_batch(epoch//opt.niter_step)
+
+ ### finetune all scales
+ if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
+ modelG.module.update_fixed_params()
+
+
+class myModel(nn.Module):
+ def __init__(self, opt, model):
+ super(myModel, self).__init__()
+ self.opt = opt
+ self.module = model
+ self.model = nn.DataParallel(model, device_ids=opt.gpu_ids)
+ self.bs_per_gpu = int(np.ceil(float(opt.batch_size) / len(opt.gpu_ids))) # batch size for each GPU
+ self.pad_bs = self.bs_per_gpu * len(opt.gpu_ids) - opt.batch_size
+
+ def forward(self, *inputs, **kwargs):
+ inputs = self.add_dummy_to_tensor(inputs, self.pad_bs)
+ outputs = self.model(*inputs, **kwargs, dummy_bs=self.pad_bs)
+ if self.pad_bs == self.bs_per_gpu: # gpu 0 does 0 batch but still returns 1 batch
+ return self.remove_dummy_from_tensor(outputs, 1)
+ return outputs
+
+ def add_dummy_to_tensor(self, tensors, add_size=0):
+ if add_size == 0 or tensors is None: return tensors
+ if type(tensors) == list or type(tensors) == tuple:
+ return [self.add_dummy_to_tensor(tensor, add_size) for tensor in tensors]
+
+ if isinstance(tensors, torch.Tensor):
+ dummy = torch.zeros_like(tensors)[:add_size]
+ tensors = torch.cat([dummy, tensors])
+ return tensors
+
+ def remove_dummy_from_tensor(self, tensors, remove_size=0):
+ if remove_size == 0 or tensors is None: return tensors
+ if type(tensors) == list or type(tensors) == tuple:
+ return [self.remove_dummy_from_tensor(tensor, remove_size) for tensor in tensors]
+
+ if isinstance(tensors, torch.Tensor):
+ tensors = tensors[remove_size:]
+ return tensors
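`add_dummy_to_tensor` / `remove_dummy_from_tensor` exist to keep the per-GPU batch size uniform under `nn.DataParallel`: zero rows are prepended before the forward pass and stripped afterwards. The round trip is easy to verify in isolation (a sketch re-declaring the two tensor-level helpers):

```python
import torch

def add_dummy(t, add_size):
    # Prepend add_size zero rows so the batch divides evenly across GPUs.
    dummy = torch.zeros_like(t)[:add_size]
    return torch.cat([dummy, t])

def remove_dummy(t, remove_size):
    # Drop the padding rows again after the forward pass.
    return t[remove_size:]

x = torch.arange(6.0).reshape(3, 2)
padded = add_dummy(x, 1)
restored = remove_dummy(padded, 1)
```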
+
+
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_features.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_features.py
new file mode 100644
index 00000000..78bba1ec
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_features.py
@@ -0,0 +1,275 @@
+"""
+ DeepSpeech features processing routines.
+ NB: Based on VOCA code. See the corresponding license restrictions.
+"""
+
+__all__ = ['conv_audios_to_deepspeech']
+
+import numpy as np
+import warnings
+import resampy
+from scipy.io import wavfile
+from python_speech_features import mfcc
+import tensorflow as tf
+
+
+def conv_audios_to_deepspeech(audios,
+ out_files,
+ num_frames_info,
+ deepspeech_pb_path,
+ audio_window_size=1,
+ audio_window_stride=1):
+ """
+ Convert list of audio files into files with DeepSpeech features.
+
+ Parameters
+ ----------
+ audios : list of str or list of None
+ Paths to input audio files.
+ out_files : list of str
+ Paths to output files with DeepSpeech features.
+ num_frames_info : list of int
+ List of numbers of frames.
+ deepspeech_pb_path : str
+ Path to DeepSpeech 0.1.0 frozen model.
+    audio_window_size : int, default 1
+ Audio window size.
+ audio_window_stride : int, default 1
+ Audio window stride.
+ """
+ # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
+ graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(
+ deepspeech_pb_path)
+
+ with tf.compat.v1.Session(graph=graph) as sess:
+ for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):
+ print(audio_file_path)
+ print(out_file_path)
+ audio_sample_rate, audio = wavfile.read(audio_file_path)
+ if audio.ndim != 1:
+ warnings.warn(
+ "Audio has multiple channels, the first channel is used")
+ audio = audio[:, 0]
+ ds_features = pure_conv_audio_to_deepspeech(
+ audio=audio,
+ audio_sample_rate=audio_sample_rate,
+ audio_window_size=audio_window_size,
+ audio_window_stride=audio_window_stride,
+ num_frames=num_frames,
+ net_fn=lambda x: sess.run(
+ logits_ph,
+ feed_dict={
+ input_node_ph: x[np.newaxis, ...],
+ input_lengths_ph: [x.shape[0]]}))
+
+ net_output = ds_features.reshape(-1, 29)
+ win_size = 16
+ zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
+ net_output = np.concatenate(
+ (zero_pad, net_output, zero_pad), axis=0)
+ windows = []
+ for window_index in range(0, net_output.shape[0] - win_size, 2):
+ windows.append(
+ net_output[window_index:window_index + win_size])
+ print(np.array(windows).shape)
+ np.save(out_file_path, np.array(windows))
+
+
+def prepare_deepspeech_net(deepspeech_pb_path):
+ """
+ Load and prepare DeepSpeech network.
+
+ Parameters
+ ----------
+ deepspeech_pb_path : str
+ Path to DeepSpeech 0.1.0 frozen model.
+
+ Returns
+ -------
+    graph : obj
+        TensorFlow graph.
+    logits_ph : obj
+        TensorFlow placeholder for `logits`.
+    input_node_ph : obj
+        TensorFlow placeholder for `input_node`.
+    input_lengths_ph : obj
+        TensorFlow placeholder for `input_lengths`.
+ """
+ # Load graph and place_holders:
+ with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f:
+ graph_def = tf.compat.v1.GraphDef()
+ graph_def.ParseFromString(f.read())
+
+ graph = tf.compat.v1.get_default_graph()
+ tf.import_graph_def(graph_def, name="deepspeech")
+ logits_ph = graph.get_tensor_by_name("deepspeech/logits:0")
+ input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0")
+ input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0")
+
+ return graph, logits_ph, input_node_ph, input_lengths_ph
+
+
+def pure_conv_audio_to_deepspeech(audio,
+ audio_sample_rate,
+ audio_window_size,
+ audio_window_stride,
+ num_frames,
+ net_fn):
+ """
+    Core routine for converting audio into DeepSpeech features.
+
+ Parameters
+ ----------
+ audio : np.array
+ Audio data.
+ audio_sample_rate : int
+ Audio sample rate.
+ audio_window_size : int
+ Audio window size.
+ audio_window_stride : int
+ Audio window stride.
+ num_frames : int or None
+ Numbers of frames.
+ net_fn : func
+ Function for DeepSpeech model call.
+
+ Returns
+ -------
+ np.array
+ DeepSpeech features.
+ """
+ target_sample_rate = 16000
+ if audio_sample_rate != target_sample_rate:
+ resampled_audio = resampy.resample(
+            x=audio.astype(np.float64),
+ sr_orig=audio_sample_rate,
+ sr_new=target_sample_rate)
+ else:
+        resampled_audio = audio.astype(np.float64)
+ input_vector = conv_audio_to_deepspeech_input_vector(
+ audio=resampled_audio.astype(np.int16),
+ sample_rate=target_sample_rate,
+ num_cepstrum=26,
+ num_context=9)
+
+ network_output = net_fn(input_vector)
+ # print(network_output.shape)
+
+ deepspeech_fps = 50
+ video_fps = 50 # Change this option if video fps is different
+ audio_len_s = float(audio.shape[0]) / audio_sample_rate
+ if num_frames is None:
+ num_frames = int(round(audio_len_s * video_fps))
+ else:
+ video_fps = num_frames / audio_len_s
+ network_output = interpolate_features(
+ features=network_output[:, 0],
+ input_rate=deepspeech_fps,
+ output_rate=video_fps,
+ output_len=num_frames)
+
+ # Make windows:
+ zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1]))
+ network_output = np.concatenate(
+ (zero_pad, network_output, zero_pad), axis=0)
+ windows = []
+ for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride):
+ windows.append(
+ network_output[window_index:window_index + audio_window_size])
+
+ return np.array(windows)
+
+
+def conv_audio_to_deepspeech_input_vector(audio,
+ sample_rate,
+ num_cepstrum,
+ num_context):
+ """
+ Convert audio raw data into DeepSpeech input vector.
+
+ Parameters
+ ----------
+ audio : np.array
+ Audio data.
+    sample_rate : int
+        Audio sample rate.
+ num_cepstrum : int
+ Number of cepstrum.
+ num_context : int
+ Number of context.
+
+ Returns
+ -------
+ np.array
+ DeepSpeech input vector.
+ """
+ # Get mfcc coefficients:
+ features = mfcc(
+ signal=audio,
+ samplerate=sample_rate,
+ numcep=num_cepstrum)
+
+ # We only keep every second feature (BiRNN stride = 2):
+ features = features[::2]
+
+ # One stride per time step in the input:
+ num_strides = len(features)
+
+ # Add empty initial and final contexts:
+ empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype)
+ features = np.concatenate((empty_context, features, empty_context))
+
+ # Create a view into the array with overlapping strides of size
+ # numcontext (past) + 1 (present) + numcontext (future):
+ window_size = 2 * num_context + 1
+ train_inputs = np.lib.stride_tricks.as_strided(
+ features,
+ shape=(num_strides, window_size, num_cepstrum),
+ strides=(features.strides[0],
+ features.strides[0], features.strides[1]),
+ writeable=False)
+
+ # Flatten the second and third dimensions:
+ train_inputs = np.reshape(train_inputs, [num_strides, -1])
+
+ train_inputs = np.copy(train_inputs)
+ train_inputs = (train_inputs - np.mean(train_inputs)) / \
+ np.std(train_inputs)
+
+ return train_inputs
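The `as_strided` call above builds overlapping MFCC context windows (past + present + future) without copying data. The same trick on a toy array, assuming a contiguous float input like the one `mfcc` produces:

```python
import numpy as np

num_context = 2
features = np.arange(12, dtype=np.float64).reshape(6, 2)  # 6 steps, 2 cepstra
window_size = 2 * num_context + 1

# Pad with empty context, then view overlapping windows of size window_size.
empty = np.zeros((num_context, 2))
padded = np.concatenate((empty, features, empty))
windows = np.lib.stride_tricks.as_strided(
    padded,
    shape=(6, window_size, 2),
    strides=(padded.strides[0], padded.strides[0], padded.strides[1]),
    writeable=False,
)
```

Each `windows[i]` is centred on step `i`: its middle row is `features[i]`, flanked by `num_context` rows of past and future context, which is exactly the input layout DeepSpeech expects after flattening.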
+
+
+def interpolate_features(features,
+ input_rate,
+ output_rate,
+ output_len):
+ """
+ Interpolate DeepSpeech features.
+
+ Parameters
+ ----------
+ features : np.array
+ DeepSpeech features.
+ input_rate : int
+ input rate (FPS).
+ output_rate : int
+ Output rate (FPS).
+ output_len : int
+ Output data length.
+
+ Returns
+ -------
+ np.array
+ Interpolated data.
+ """
+ input_len = features.shape[0]
+ num_features = features.shape[1]
+ input_timestamps = np.arange(input_len) / float(input_rate)
+ output_timestamps = np.arange(output_len) / float(output_rate)
+ output_features = np.zeros((output_len, num_features))
+ for feature_idx in range(num_features):
+ output_features[:, feature_idx] = np.interp(
+ x=output_timestamps,
+ xp=input_timestamps,
+ fp=features[:, feature_idx])
+ return output_features
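`interpolate_features` resamples each DeepSpeech channel from the network's fixed 50 fps onto the video timeline with `np.interp`. A compact check on a linear ramp (the function is re-declared so the snippet is self-contained):

```python
import numpy as np

def interpolate_features(features, input_rate, output_rate, output_len):
    # Linearly interpolate each feature channel onto the output timestamps.
    input_len, num_features = features.shape
    input_timestamps = np.arange(input_len) / float(input_rate)
    output_timestamps = np.arange(output_len) / float(output_rate)
    output_features = np.zeros((output_len, num_features))
    for feature_idx in range(num_features):
        output_features[:, feature_idx] = np.interp(
            output_timestamps, input_timestamps, features[:, feature_idx])
    return output_features

feats = np.arange(10, dtype=np.float64).reshape(10, 1)  # ramp sampled at 50 fps
res = interpolate_features(feats, input_rate=50, output_rate=25, output_len=5)
```

Halving the rate on a linear ramp keeps every other sample, so `res[:, 0]` is `[0, 2, 4, 6, 8]`.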
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_store.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_store.py
new file mode 100644
index 00000000..5595a4d5
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_store.py
@@ -0,0 +1,172 @@
+"""
+ Routines for loading DeepSpeech model.
+"""
+
+__all__ = ['get_deepspeech_model_file']
+
+import os
+import zipfile
+import logging
+import hashlib
+
+
+deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features'
+
+
+def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")):
+ """
+    Return the location of the pretrained model file on the local file system. This function will download it from the online
+    model zoo when the model cannot be found or its hash mismatches. The root directory will be created if it doesn't exist.
+
+ Parameters
+ ----------
+ local_model_store_dir_path : str, default $TENSORFLOW_HOME/models
+ Location for keeping the model parameters.
+
+ Returns
+ -------
+ file_path
+ Path to the requested pretrained model file.
+ """
+ sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e"
+ file_name = "deepspeech-0_1_0-b90017e8.pb"
+ local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path)
+ file_path = os.path.join(local_model_store_dir_path, file_name)
+ if os.path.exists(file_path):
+ if _check_sha1(file_path, sha1_hash):
+ return file_path
+ else:
+ logging.warning("Mismatch in the content of model file detected. Downloading again.")
+ else:
+ logging.info("Model file not found. Downloading to {}.".format(file_path))
+
+ if not os.path.exists(local_model_store_dir_path):
+ os.makedirs(local_model_store_dir_path)
+
+ zip_file_path = file_path + ".zip"
+ _download(
+ url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format(
+ repo_url=deepspeech_features_repo_url,
+ repo_release_tag="v0.0.1",
+ file_name=file_name),
+ path=zip_file_path,
+ overwrite=True)
+ with zipfile.ZipFile(zip_file_path) as zf:
+ zf.extractall(local_model_store_dir_path)
+ os.remove(zip_file_path)
+
+ if _check_sha1(file_path, sha1_hash):
+ return file_path
+ else:
+ raise ValueError("Downloaded file has different hash. Please try again.")
+
+
+def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
+ """
+    Download a file from a given URL.
+
+ Parameters
+ ----------
+ url : str
+ URL to download
+ path : str, optional
+ Destination path to store downloaded file. By default stores to the
+ current directory with same name as in url.
+ overwrite : bool, optional
+ Whether to overwrite destination file if already exists.
+ sha1_hash : str, optional
+ Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified
+ but doesn't match.
+    retries : integer, default 5
+        The number of times to attempt the download in case of failure or non-200 return codes.
+ verify_ssl : bool, default True
+ Verify SSL certificates.
+
+ Returns
+ -------
+ str
+ The file path of the downloaded file.
+ """
+ import warnings
+ try:
+ import requests
+ except ImportError:
+ class requests_failed_to_import(object):
+ pass
+ requests = requests_failed_to_import
+
+ if path is None:
+ fname = url.split("/")[-1]
+ # Empty filenames are invalid
+ assert fname, "Can't construct file-name from this URL. Please set the `path` option manually."
+ else:
+ path = os.path.expanduser(path)
+ if os.path.isdir(path):
+ fname = os.path.join(path, url.split("/")[-1])
+ else:
+ fname = path
+ assert retries >= 0, "Number of retries should be at least 0"
+
+ if not verify_ssl:
+ warnings.warn(
+ "Unverified HTTPS request is being made (verify_ssl=False). "
+ "Adding certificate verification is strongly advised.")
+
+ if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)):
+ dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+ while retries + 1 > 0:
+            # Disable pylint too-broad-exception warning
+ # pylint: disable=W0703
+ try:
+ print("Downloading {} from {}...".format(fname, url))
+ r = requests.get(url, stream=True, verify=verify_ssl)
+ if r.status_code != 200:
+ raise RuntimeError("Failed downloading url {}".format(url))
+ with open(fname, "wb") as f:
+ for chunk in r.iter_content(chunk_size=1024):
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ if sha1_hash and not _check_sha1(fname, sha1_hash):
+ raise UserWarning("File {} is downloaded but the content hash does not match."
+ " The repo may be outdated or download may be incomplete. "
+ "If the `repo_url` is overridden, consider switching to "
+ "the default repo.".format(fname))
+ break
+ except Exception as e:
+ retries -= 1
+ if retries <= 0:
+ raise e
+ else:
+ print("download failed, retrying, {} attempt{} left"
+ .format(retries, "s" if retries > 1 else ""))
+
+ return fname
+
+
+def _check_sha1(filename, sha1_hash):
+ """
+ Check whether the sha1 hash of the file content matches the expected hash.
+
+ Parameters
+ ----------
+ filename : str
+ Path to the file.
+ sha1_hash : str
+ Expected sha1 hash in hexadecimal digits.
+
+ Returns
+ -------
+ bool
+ Whether the file content matches the expected hash.
+ """
+ sha1 = hashlib.sha1()
+ with open(filename, "rb") as f:
+ while True:
+ data = f.read(1048576)
+ if not data:
+ break
+ sha1.update(data)
+
+ return sha1.hexdigest() == sha1_hash
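The chunked hashing in `_check_sha1` is equivalent to a one-shot `hashlib.sha1` over the same bytes; a small self-contained sanity check (the payload and temp file are stand-ins of my own, not toolkit data):

```python
import hashlib
import os
import tempfile

def check_sha1(filename, sha1_hash, chunk=1048576):
    """Chunked SHA-1 check, mirroring _check_sha1 above."""
    sha1 = hashlib.sha1()
    with open(filename, "rb") as f:
        while True:
            data = f.read(chunk)
            if not data:
                break
            sha1.update(data)
    return sha1.hexdigest() == sha1_hash

# write a payload larger than one read() call would need, then verify
payload = b"talkingface" * 4096
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(payload)
    path = tmp.name
ok = check_sha1(path, hashlib.sha1(payload).hexdigest())
bad = check_sha1(path, "0" * 40)
os.remove(path)
print(ok, bad)  # True False
```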
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_ds_features.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_ds_features.py
new file mode 100644
index 00000000..34c1fe14
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_ds_features.py
@@ -0,0 +1,129 @@
+"""
+ Script for extracting DeepSpeech features from an audio file.
+"""
+
+import os
+import argparse
+import numpy as np
+import pandas as pd
+from deepspeech_store import get_deepspeech_model_file
+from deepspeech_features import conv_audios_to_deepspeech
+
+
+def parse_args():
+ """
+ Create Python script parameters.
+
+ Returns
+ -------
+ argparse.Namespace
+ Parsed arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="Extract DeepSpeech features from audio file",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ "--input",
+ type=str,
+ required=True,
+ help="path to input audio file or directory")
+ parser.add_argument(
+ "--output",
+ type=str,
+ help="path to output file with DeepSpeech features")
+ parser.add_argument(
+ "--deepspeech",
+ type=str,
+ help="path to DeepSpeech 0.1.0 frozen model")
+ parser.add_argument(
+ "--metainfo",
+ type=str,
+ help="path to file with meta-information")
+
+ args = parser.parse_args()
+ return args
+
+
+def extract_features(in_audios,
+ out_files,
+ deepspeech_pb_path,
+ metainfo_file_path=None):
+ """
+ Extract DeepSpeech features from audio files.
+ Parameters
+ ----------
+ in_audios : list of str
+ Paths to input audio files.
+ out_files : list of str
+ Paths to output files with DeepSpeech features.
+ deepspeech_pb_path : str
+ Path to DeepSpeech 0.1.0 frozen model.
+ metainfo_file_path : str, default None
+ Path to file with meta-information.
+ """
+ if metainfo_file_path is None:
+ num_frames_info = [None] * len(in_audios)
+ else:
+ train_df = pd.read_csv(
+ metainfo_file_path,
+ sep="\t",
+ index_col=False,
+ dtype={"Id": int, "File": str, "Count": int})  # np.int/np.unicode were removed in NumPy 1.24
+ num_frames_info = train_df["Count"].values
+ assert (len(num_frames_info) == len(in_audios))
+
+ for i, in_audio in enumerate(in_audios):
+ if not out_files[i]:
+ file_stem, _ = os.path.splitext(in_audio)
+ out_files[i] = file_stem + ".npy"
+ #print(out_files[i])
+ conv_audios_to_deepspeech(
+ audios=in_audios,
+ out_files=out_files,
+ num_frames_info=num_frames_info,
+ deepspeech_pb_path=deepspeech_pb_path)
+
+
+def main():
+ """
+ Main body of script.
+ """
+ args = parse_args()
+ in_audio = os.path.expanduser(args.input)
+ if not os.path.exists(in_audio):
+ raise FileNotFoundError("Input file/directory doesn't exist: {}".format(in_audio))
+ deepspeech_pb_path = args.deepspeech
+ if deepspeech_pb_path is None:
+ deepspeech_pb_path = ""
+ if deepspeech_pb_path:
+ deepspeech_pb_path = os.path.expanduser(args.deepspeech)
+ if not os.path.exists(deepspeech_pb_path):
+ deepspeech_pb_path = get_deepspeech_model_file()
+ if os.path.isfile(in_audio):
+ extract_features(
+ in_audios=[in_audio],
+ out_files=[args.output],
+ deepspeech_pb_path=deepspeech_pb_path,
+ metainfo_file_path=args.metainfo)
+ else:
+ audio_file_paths = []
+ for file_name in os.listdir(in_audio):
+ if not os.path.isfile(os.path.join(in_audio, file_name)):
+ continue
+ _, file_ext = os.path.splitext(file_name)
+ if file_ext.lower() == ".wav":
+ audio_file_path = os.path.join(in_audio, file_name)
+ audio_file_paths.append(audio_file_path)
+ audio_file_paths = sorted(audio_file_paths)
+ out_file_paths = [""] * len(audio_file_paths)
+ extract_features(
+ in_audios=audio_file_paths,
+ out_files=out_file_paths,
+ deepspeech_pb_path=deepspeech_pb_path,
+ metainfo_file_path=args.metainfo)
+
+
+if __name__ == "__main__":
+ main()
+
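When `--output` is omitted, `extract_features` derives the output path by swapping the input's extension for `.npy`. A minimal sketch of that naming rule (the example paths are hypothetical):

```python
import os

def default_out_path(in_audio):
    # fallback used in extract_features above: same stem, .npy extension
    stem, _ = os.path.splitext(in_audio)
    return stem + ".npy"

p = default_out_path(os.path.join("clips", "french.wav"))
print(p)  # clips/french.npy on POSIX systems
```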
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_wav.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_wav.py
new file mode 100644
index 00000000..8458c5f2
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_wav.py
@@ -0,0 +1,87 @@
+"""
+ Script for extracting audio (16-bit, mono, 16000 Hz) from a video file.
+"""
+
+import os
+import argparse
+import subprocess
+
+
+def parse_args():
+ """
+ Create Python script parameters.
+
+ Returns
+ -------
+ argparse.Namespace
+ Parsed arguments.
+ """
+ parser = argparse.ArgumentParser(
+ description="Extract audio from video file",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument(
+ "--in-video",
+ type=str,
+ required=True,
+ help="path to input video file or directory")
+ parser.add_argument(
+ "--out-audio",
+ type=str,
+ help="path to output audio file")
+
+ args = parser.parse_args()
+ return args
+
+
+def extract_audio(in_video,
+ out_audio):
+ """
+ Extract audio from a video file.
+
+ Parameters
+ ----------
+ in_video : str
+ Path to input video file.
+ out_audio : str
+ Path to output audio file.
+ """
+ if not out_audio:
+ file_stem, _ = os.path.splitext(in_video)
+ out_audio = file_stem + ".wav"
+ # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}"
+ # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
+ # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}"
+ command = ["ffmpeg", "-i", in_video, "-vn", "-acodec", "pcm_s16le", "-ac", "1", "-ar", "16000", out_audio]
+ subprocess.call(command)  # argv list (no shell) keeps paths with spaces safe
+
+
+def main():
+ """
+ Main body of script.
+ """
+ args = parse_args()
+ in_video = os.path.expanduser(args.in_video)
+ if not os.path.exists(in_video):
+ raise FileNotFoundError("Input file/directory doesn't exist: {}".format(in_video))
+ if os.path.isfile(in_video):
+ extract_audio(
+ in_video=in_video,
+ out_audio=args.out_audio)
+ else:
+ video_file_paths = []
+ for file_name in os.listdir(in_video):
+ if not os.path.isfile(os.path.join(in_video, file_name)):
+ continue
+ _, file_ext = os.path.splitext(file_name)
+ if file_ext.lower() in (".mp4", ".mkv", ".avi"):
+ video_file_path = os.path.join(in_video, file_name)
+ video_file_paths.append(video_file_path)
+ video_file_paths = sorted(video_file_paths)
+ for video_file_path in video_file_paths:
+ extract_audio(
+ in_video=video_file_path,
+ out_audio="")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/fea_win.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/fea_win.py
new file mode 100644
index 00000000..df9e27b4
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/fea_win.py
@@ -0,0 +1,11 @@
+import numpy as np
+
+net_output = np.load('french.ds.npy').reshape(-1, 29)
+win_size = 16
+zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
+net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0)
+windows = []
+for window_index in range(0, net_output.shape[0] - win_size, 2):
+ windows.append(net_output[window_index:window_index + win_size])
+print(np.array(windows).shape)
+np.save('aud_french.npy', np.array(windows))
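`fea_win.py` converts per-frame DeepSpeech logits of shape `(N, 29)` into overlapping 16-frame context windows at stride 2, i.e. roughly `N/2` windows of shape `(16, 29)`. The same shape arithmetic on random stand-in data (since `french.ds.npy` is not shipped with the repo):

```python
import numpy as np

net_output = np.random.rand(100, 29)  # stand-in for french.ds.npy
win_size = 16
zero_pad = np.zeros((win_size // 2, net_output.shape[1]))
padded = np.concatenate((zero_pad, net_output, zero_pad), axis=0)  # (116, 29)
# one window every 2 input frames: range(0, 116 - 16, 2) -> 50 windows
windows = np.array([padded[i:i + win_size]
                    for i in range(0, padded.shape[0] - win_size, 2)])
print(windows.shape)  # (50, 16, 29)
```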
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/load_audface.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/load_audface.py
new file mode 100644
index 00000000..b6d59727
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/load_audface.py
@@ -0,0 +1,114 @@
+import os
+import torch
+import numpy as np
+import imageio
+import json
+import torch.nn.functional as F
+import cv2
+
+
+def load_audface_data(basedir, testskip=1, test_file=None, aud_file=None, test_size=-1):
+ if test_file is not None:
+ with open(os.path.join(basedir, test_file)) as fp:
+ meta = json.load(fp)
+ poses = []
+ auds = []
+ aud_features = np.load(os.path.join(basedir, aud_file))
+ cur_id = 0
+ for frame in meta['frames'][::testskip]:
+ poses.append(np.array(frame['transform_matrix']))
+ aud_id = cur_id
+ auds.append(aud_features[aud_id])
+ cur_id = cur_id + 1
+ if cur_id == aud_features.shape[0] or cur_id == test_size:
+ break
+ poses = np.array(poses).astype(np.float32)
+ auds = np.array(auds).astype(np.float32)
+ bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg'))
+ H, W = bc_img.shape[0], bc_img.shape[1]
+ focal, cx, cy = float(meta['focal_len']), float(
+ meta['cx']), float(meta['cy'])
+ return poses, auds, bc_img, [H, W, focal, cx, cy]
+
+ splits = ['train', 'val']
+ metas = {}
+ for s in splits:
+ with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
+ metas[s] = json.load(fp)
+ all_com_imgs = []
+ all_poses = []
+ all_auds = []
+ all_sample_rects = []
+ aud_features = np.load(os.path.join(basedir, 'aud.npy'))
+ counts = [0]
+ for s in splits:
+ meta = metas[s]
+ com_imgs = []
+ poses = []
+ auds = []
+ sample_rects = []
+ if s == 'train' or testskip == 0:
+ skip = 1
+ else:
+ skip = testskip
+
+ for frame in meta['frames'][::skip]:
+ filename = os.path.join(basedir, 'com_imgs',
+ str(frame['img_id']) + '.jpg')
+ com_imgs.append(filename)
+ poses.append(np.array(frame['transform_matrix']))
+ auds.append(
+ aud_features[min(frame['aud_id'], aud_features.shape[0]-1)])
+ sample_rects.append(np.array(frame['face_rect'], dtype=np.int32))
+ com_imgs = np.array(com_imgs)
+ poses = np.array(poses).astype(np.float32)
+ auds = np.array(auds).astype(np.float32)
+ counts.append(counts[-1] + com_imgs.shape[0])
+ all_com_imgs.append(com_imgs)
+ all_poses.append(poses)
+ all_auds.append(auds)
+ all_sample_rects.append(sample_rects)
+ i_split = [np.arange(counts[i], counts[i+1]) for i in range(len(splits))]
+ com_imgs = np.concatenate(all_com_imgs, 0)
+ poses = np.concatenate(all_poses, 0)
+ auds = np.concatenate(all_auds, 0)
+ sample_rects = np.concatenate(all_sample_rects, 0)
+
+ bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg'))
+
+ H, W = bc_img.shape[:2]
+ focal, cx, cy = float(meta['focal_len']), float(
+ meta['cx']), float(meta['cy'])
+
+ return com_imgs, poses, auds, bc_img, [H, W, focal, cx, cy], \
+ sample_rects, i_split
+
+
+def load_test_data(basedir, aud_file, test_pose_file='transforms_train.json',
+ testskip=1, test_size=-1, aud_start=0):
+ with open(os.path.join(basedir, test_pose_file)) as fp:
+ meta = json.load(fp)
+ poses = []
+ auds = []
+ aud_features = np.load(aud_file)
+ aud_ids = []
+ cur_id = 0
+ for frame in meta['frames'][::testskip]:
+ poses.append(np.array(frame['transform_matrix']))
+ auds.append(
+ aud_features[min(aud_start+cur_id, aud_features.shape[0]-1)])
+ aud_ids.append(aud_start+cur_id)
+ cur_id = cur_id + 1
+ if cur_id == test_size or cur_id == aud_features.shape[0]:
+ break
+ poses = np.array(poses).astype(np.float32)
+ auds = np.array(auds).astype(np.float32)
+ bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg'))
+ H, W = bc_img.shape[0], bc_img.shape[1]
+ focal, cx, cy = float(meta['focal_len']), float(
+ meta['cx']), float(meta['cy'])
+
+ with open(os.path.join(basedir, 'transforms_train.json')) as fp:
+ meta_torso = json.load(fp)
+ torso_pose = np.array(meta_torso['frames'][0]['transform_matrix'])
+ return poses, auds, bc_img, [H, W, focal, cx, cy], aud_ids, torso_pose
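Each loader returns the intrinsics `[H, W, focal, cx, cy]`, which the renderer's ray generation consumes downstream. As a rough illustration of how such intrinsics define per-pixel ray directions in camera space (my own pinhole sketch, not the toolkit's `get_rays`):

```python
import numpy as np

def pixel_dirs(H, W, focal, cx, cy):
    # pinhole camera: x right, y up, camera looks down -z (NeRF convention)
    i, j = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
    return np.stack([(i - cx) / focal,
                     -(j - cy) / focal,
                     -np.ones((H, W))], axis=-1)

d = pixel_dirs(4, 4, focal=2.0, cx=2.0, cy=2.0)
print(d.shape)      # (4, 4, 3)
print(d[2, 2, 2])   # -1.0: the principal ray points straight down -z
```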
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf.py
new file mode 100644
index 00000000..b912f6c6
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf.py
@@ -0,0 +1,1136 @@
+from load_audface import load_audface_data, load_test_data
+import os
+import sys
+import numpy as np
+import imageio
+import json
+import random
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm, trange
+from natsort import natsorted
+import cv2
+
+from run_nerf_helpers import *
+
+device = torch.device('cuda', 0)
+device_torso = torch.device('cuda', 0)
+np.random.seed(0)
+DEBUG = False
+
+
+def rot_to_euler(R):
+ batch_size, _, _ = R.shape
+ e = torch.ones((batch_size, 3)).cuda()
+
+ R00 = R[:, 0, 0]
+ R01 = R[:, 0, 1]
+ R02 = R[:, 0, 2]
+ R10 = R[:, 1, 0]
+ R11 = R[:, 1, 1]
+ R12 = R[:, 1, 2]
+ R20 = R[:, 2, 0]
+ R21 = R[:, 2, 1]
+ R22 = R[:, 2, 2]
+ e[:, 2] = torch.atan2(R00, -R01)
+ e[:, 1] = torch.asin(-R02)
+ e[:, 0] = torch.atan2(R22, R12)
+ return e
+
+
+def pose_to_euler_trans(poses):
+ e = rot_to_euler(poses)
+ t = poses[:, :3, 3]
+ return torch.cat((e, t), dim=1)
+
+
+def euler2rot(euler_angle):
+ batch_size = euler_angle.shape[0]
+ theta = euler_angle[:, 0].reshape(-1, 1, 1)
+ phi = euler_angle[:, 1].reshape(-1, 1, 1)
+ psi = euler_angle[:, 2].reshape(-1, 1, 1)
+ one = torch.ones((batch_size, 1, 1), dtype=torch.float32,
+ device=euler_angle.device)
+ zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32,
+ device=euler_angle.device)
+ rot_x = torch.cat((
+ torch.cat((one, zero, zero), 1),
+ torch.cat((zero, theta.cos(), theta.sin()), 1),
+ torch.cat((zero, -theta.sin(), theta.cos()), 1),
+ ), 2)
+ rot_y = torch.cat((
+ torch.cat((phi.cos(), zero, -phi.sin()), 1),
+ torch.cat((zero, one, zero), 1),
+ torch.cat((phi.sin(), zero, phi.cos()), 1),
+ ), 2)
+ rot_z = torch.cat((
+ torch.cat((psi.cos(), -psi.sin(), zero), 1),
+ torch.cat((psi.sin(), psi.cos(), zero), 1),
+ torch.cat((zero, zero, one), 1)
+ ), 2)
+ return torch.bmm(rot_x, torch.bmm(rot_y, rot_z))
+
+
+def batchify(fn, chunk):
+ """Constructs a version of 'fn' that applies to smaller batches.
+ """
+ if chunk is None:
+ return fn
+
+ def ret(inputs):
+ return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
+ return ret
+
+
+def run_network(inputs, viewdirs, aud_para, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
+ """Prepares inputs and applies network 'fn'.
+ """
+ inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
+ embedded = embed_fn(inputs_flat)
+ aud = aud_para.unsqueeze(0).expand(inputs_flat.shape[0], -1)
+ embedded = torch.cat((embedded, aud), -1)
+ if viewdirs is not None:
+ input_dirs = viewdirs[:, None].expand(inputs.shape)
+ input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
+ embedded_dirs = embeddirs_fn(input_dirs_flat)
+ embedded = torch.cat([embedded, embedded_dirs], -1)
+
+ outputs_flat = batchify(fn, netchunk)(embedded)
+ outputs = torch.reshape(outputs_flat, list(
+ inputs.shape[:-1]) + [outputs_flat.shape[-1]])
+ return outputs
+
+
+def batchify_rays(rays_flat, bc_rgb, aud_para, chunk=1024*32, **kwargs):
+ """Render rays in smaller minibatches to avoid OOM.
+ """
+ all_ret = {}
+ for i in range(0, rays_flat.shape[0], chunk):
+ ret = render_rays(rays_flat[i:i+chunk], bc_rgb[i:i+chunk],
+ aud_para, **kwargs)
+ for k in ret:
+ if k not in all_ret:
+ all_ret[k] = []
+ all_ret[k].append(ret[k])
+
+ all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
+ return all_ret
+
+
+def render_dynamic_face(H, W, focal, cx, cy, chunk=1024*32, rays=None, bc_rgb=None, aud_para=None,
+ c2w=None, ndc=True, near=0., far=1.,
+ use_viewdirs=False, c2w_staticcam=None,
+ **kwargs):
+ if c2w is not None:
+ # special case to render full image
+ rays_o, rays_d = get_rays(H, W, focal, c2w, cx, cy, c2w.device)
+ bc_rgb = bc_rgb.reshape(-1, 3)
+ else:
+ # use provided ray batch
+ rays_o, rays_d = rays
+
+ if use_viewdirs:
+ # provide ray directions as input
+ viewdirs = rays_d
+ if c2w_staticcam is not None:
+ # special case to visualize effect of viewdirs
+ rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam, cx, cy)
+ viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
+ viewdirs = torch.reshape(viewdirs, [-1, 3]).float()
+
+ sh = rays_d.shape # [..., 3]
+ if ndc:
+ # for forward facing scenes
+ rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
+
+ # Create ray batch
+ rays_o = torch.reshape(rays_o, [-1, 3]).float()
+ rays_d = torch.reshape(rays_d, [-1, 3]).float()
+
+ near, far = near * \
+ torch.ones_like(rays_d[..., :1]), far * \
+ torch.ones_like(rays_d[..., :1])
+ rays = torch.cat([rays_o, rays_d, near, far], -1)
+ if use_viewdirs:
+ rays = torch.cat([rays, viewdirs], -1)
+
+ # Render and reshape
+ all_ret = batchify_rays(rays, bc_rgb, aud_para, chunk, **kwargs)
+ for k in all_ret:
+ k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
+ all_ret[k] = torch.reshape(all_ret[k], k_sh)
+
+ k_extract = ['rgb_map', 'disp_map', 'acc_map', 'last_weight', 'rgb_map_fg']
+ ret_list = [all_ret[k] for k in k_extract]
+ ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
+ return ret_list + [ret_dict]
+
+
+def render(H, W, focal, cx, cy, chunk=1024*32, rays=None, c2w=None, ndc=True,
+ near=0., far=1.,
+ use_viewdirs=False, c2w_staticcam=None,
+ **kwargs):
+ """Render rays
+ Args:
+ H: int. Height of image in pixels.
+ W: int. Width of image in pixels.
+ focal: float. Focal length of pinhole camera.
+ chunk: int. Maximum number of rays to process simultaneously. Used to
+ control maximum memory usage. Does not affect final results.
+ rays: array of shape [2, batch_size, 3]. Ray origin and direction for
+ each example in batch.
+ c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
+ ndc: bool. If True, represent ray origin, direction in NDC coordinates.
+ near: float or array of shape [batch_size]. Nearest distance for a ray.
+ far: float or array of shape [batch_size]. Farthest distance for a ray.
+ use_viewdirs: bool. If True, use viewing direction of a point in space in model.
+ c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
+ camera while using other c2w argument for viewing directions.
+ Returns:
+ rgb_map: [batch_size, 3]. Predicted RGB values for rays.
+ disp_map: [batch_size]. Disparity map. Inverse of depth.
+ acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
+ extras: dict with everything returned by render_rays().
+ """
+ if c2w is not None:
+ # special case to render full image
+ rays_o, rays_d = get_rays(H, W, focal, c2w, cx, cy)
+ else:
+ # use provided ray batch
+ rays_o, rays_d = rays
+
+ if use_viewdirs:
+ # provide ray directions as input
+ viewdirs = rays_d
+ if c2w_staticcam is not None:
+ # special case to visualize effect of viewdirs
+ rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam, cx, cy)
+ viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
+ viewdirs = torch.reshape(viewdirs, [-1, 3]).float()
+
+ sh = rays_d.shape # [..., 3]
+ if ndc:
+ # for forward facing scenes
+ rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
+
+ # Create ray batch
+ rays_o = torch.reshape(rays_o, [-1, 3]).float()
+ rays_d = torch.reshape(rays_d, [-1, 3]).float()
+
+ near, far = near * \
+ torch.ones_like(rays_d[..., :1]), far * \
+ torch.ones_like(rays_d[..., :1])
+ rays = torch.cat([rays_o, rays_d, near, far], -1)
+ if use_viewdirs:
+ rays = torch.cat([rays, viewdirs], -1)
+
+ # Render and reshape
+ all_ret = batchify_rays(rays, chunk, **kwargs)
+ for k in all_ret:
+ k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
+ all_ret[k] = torch.reshape(all_ret[k], k_sh)
+
+ k_extract = ['rgb_map', 'disp_map', 'acc_map']
+ ret_list = [all_ret[k] for k in k_extract]
+ ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
+ return ret_list + [ret_dict]
+
+
+def render_path(render_poses, aud_paras, bc_img, hwfcxy,
+ chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
+
+ H, W, focal, cx, cy = hwfcxy
+
+ if render_factor != 0:
+ # Render downsampled for speed
+ H = H//render_factor
+ W = W//render_factor
+ focal = focal/render_factor
+
+ rgbs = []
+ disps = []
+ last_weights = []
+ rgb_fgs = []
+
+ t = time.time()
+ for i, c2w in enumerate(tqdm(render_poses)):
+ print(i, time.time() - t)
+ t = time.time()
+ rgb, disp, acc, last_weight, rgb_fg, _ = render_dynamic_face(
+ H, W, focal, cx, cy, chunk=chunk, c2w=c2w[:3,
+ :4], aud_para=aud_paras[i], bc_rgb=bc_img,
+ **render_kwargs)
+ rgbs.append(rgb.cpu().numpy())
+ disps.append(disp.cpu().numpy())
+ last_weights.append(last_weight.cpu().numpy())
+ rgb_fgs.append(rgb_fg.cpu().numpy())
+ # if i == 0:
+ # print(rgb.shape, disp.shape)
+
+ """
+ if gt_imgs is not None and render_factor==0:
+ p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
+ print(p)
+ """
+
+ if savedir is not None:
+ rgb8 = to8b(rgbs[-1])
+ filename = os.path.join(savedir, '{:03d}.png'.format(i))
+ imageio.imwrite(filename, rgb8)
+
+ rgbs = np.stack(rgbs, 0)
+ disps = np.stack(disps, 0)
+ last_weights = np.stack(last_weights, 0)
+ rgb_fgs = np.stack(rgb_fgs, 0)
+
+ return rgbs, disps, last_weights, rgb_fgs
+
+
+def create_nerf(args, ext, dim_aud, device_spec=torch.device('cuda', 0), with_audatt=False):
+ """Instantiate NeRF's MLP model.
+ """
+ embed_fn, input_ch = get_embedder(
+ args.multires, args.i_embed, device=device_spec)
+
+ input_ch_views = 0
+ embeddirs_fn = None
+ if args.use_viewdirs:
+ embeddirs_fn, input_ch_views = get_embedder(
+ args.multires_views, args.i_embed, device=device_spec)
+ output_ch = 5 if args.N_importance > 0 else 4
+ skips = [4]
+ model = FaceNeRF(D=args.netdepth, W=args.netwidth,
+ input_ch=input_ch, dim_aud=dim_aud,
+ output_ch=output_ch, skips=skips,
+ input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device_spec)
+ grad_vars = list(model.parameters())
+
+ model_fine = None
+ if args.N_importance > 0:
+ model_fine = FaceNeRF(D=args.netdepth_fine, W=args.netwidth_fine,
+ input_ch=input_ch, dim_aud=dim_aud,
+ output_ch=output_ch, skips=skips,
+ input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device_spec)
+ grad_vars += list(model_fine.parameters())
+
+ def network_query_fn(inputs, viewdirs, aud_para, network_fn): \
+ return run_network(inputs, viewdirs, aud_para, network_fn,
+ embed_fn=embed_fn, embeddirs_fn=embeddirs_fn, netchunk=args.netchunk)
+
+ # Create optimizer
+ optimizer = torch.optim.Adam(
+ params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
+
+ start = 0
+ basedir = args.basedir
+ expname = args.expname
+
+ ##########################
+
+ # Load checkpoints
+ if args.ft_path is not None and args.ft_path != 'None':
+ ckpts = [args.ft_path]
+ else:
+ ckpts = [os.path.join(basedir, expname, f) for f in natsorted(
+ os.listdir(os.path.join(basedir, expname))) if ext in f]
+
+ print('Found ckpts', ckpts)
+ learned_codes_dict = None
+ AudNet_state = None
+ optimizer_aud_state = None
+ AudAttNet_state = None
+ if len(ckpts) > 0 and not args.no_reload:
+ ckpt_path = ckpts[-1]
+ print('Reloading from', ckpt_path)
+ ckpt = torch.load(ckpt_path, map_location=device)
+
+ start = ckpt['global_step']
+ optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+ AudNet_state = ckpt['network_audnet_state_dict']
+ optimizer_aud_state = ckpt['optimizer_aud_state_dict']
+ if with_audatt:
+ AudAttNet_state = ckpt['network_audattnet_state_dict']
+
+ # Load model
+ model.load_state_dict(ckpt['network_fn_state_dict'])
+ if model_fine is not None:
+ model_fine.load_state_dict(ckpt['network_fine_state_dict'])
+
+ ##########################
+
+ render_kwargs_train = {
+ 'network_query_fn': network_query_fn,
+ 'perturb': args.perturb,
+ 'N_importance': args.N_importance,
+ 'network_fine': model_fine,
+ 'N_samples': args.N_samples,
+ 'network_fn': model,
+ 'use_viewdirs': args.use_viewdirs,
+ 'white_bkgd': args.white_bkgd,
+ 'raw_noise_std': args.raw_noise_std,
+ }
+
+ # NDC only good for LLFF-style forward facing data
+ if args.dataset_type != 'llff' or args.no_ndc:
+ print('Not ndc!')
+ render_kwargs_train['ndc'] = False
+ render_kwargs_train['lindisp'] = args.lindisp
+
+ render_kwargs_test = {
+ k: render_kwargs_train[k] for k in render_kwargs_train}
+ render_kwargs_test['perturb'] = False
+ render_kwargs_test['raw_noise_std'] = 0.
+
+ return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, learned_codes_dict, \
+ AudNet_state, optimizer_aud_state, AudAttNet_state
+
+
+def raw2outputs(raw, z_vals, rays_d, bc_rgb, raw_noise_std=0, white_bkgd=False, pytest=False):
+ """Transforms model's predictions to semantically meaningful values.
+ Args:
+ raw: [num_rays, num_samples along ray, 4]. Prediction from model.
+ z_vals: [num_rays, num_samples along ray]. Integration time.
+ rays_d: [num_rays, 3]. Direction of each ray.
+ Returns:
+ rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
+ disp_map: [num_rays]. Disparity map. Inverse of depth map.
+ acc_map: [num_rays]. Sum of weights along each ray.
+ weights: [num_rays, num_samples]. Weights assigned to each sampled color.
+ depth_map: [num_rays]. Estimated distance to object.
+ """
+ def raw2alpha(raw, dists, act_fn=F.relu): return 1. - \
+ torch.exp(-(act_fn(raw)+1e-6)*dists)
+
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
+ dists = torch.cat([dists, torch.tensor([1e10], device=z_vals.device).expand(
+ dists[..., :1].shape)], -1) # [N_rays, N_samples]
+
+ dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
+
+ rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3]
+ rgb = torch.cat((rgb[:, :-1, :], bc_rgb.unsqueeze(1)), dim=1)
+ noise = 0.
+ if raw_noise_std > 0.:
+ noise = torch.randn(raw[..., 3].shape) * raw_noise_std
+
+ # Overwrite randomly sampled data if pytest
+ if pytest:
+ np.random.seed(0)
+ noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std
+ noise = torch.Tensor(noise)
+
+ alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples]
+ weights = alpha * \
+ torch.cumprod(
+ torch.cat([torch.ones((alpha.shape[0], 1), device=alpha.device), 1.-alpha + 1e-10], -1), -1)[:, :-1]
+ rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3]
+
+ rgb_map_fg = torch.sum(weights[:, :-1, None]*rgb[:, :-1, :], -2)
+
+ depth_map = torch.sum(weights * z_vals, -1)
+ disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map),
+ depth_map / torch.sum(weights, -1))
+ acc_map = torch.sum(weights, -1)
+
+ if white_bkgd:
+ rgb_map = rgb_map + (1.-acc_map[..., None])
+
+ return rgb_map, disp_map, acc_map, weights, depth_map, rgb_map_fg
+
+
+def render_rays(ray_batch,
+ bc_rgb,
+ aud_para,
+ network_fn,
+ network_query_fn,
+ N_samples,
+ retraw=False,
+ lindisp=False,
+ perturb=0.,
+ N_importance=0,
+ network_fine=None,
+ white_bkgd=False,
+ raw_noise_std=0.,
+ verbose=False,
+ pytest=False):
+ """Volumetric rendering.
+ Args:
+ ray_batch: array of shape [batch_size, ...]. All information necessary
+ for sampling along a ray, including: ray origin, ray direction, min
+ dist, max dist, and unit-magnitude viewing direction.
+ network_fn: function. Model for predicting RGB and density at each point
+ in space.
+ network_query_fn: function used for passing queries to network_fn.
+ N_samples: int. Number of different times to sample along each ray.
+ retraw: bool. If True, include model's raw, unprocessed predictions.
+ lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
+ perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
+ random points in time.
+ N_importance: int. Number of additional times to sample along each ray.
+ These samples are only passed to network_fine.
+ network_fine: "fine" network with same spec as network_fn.
+ white_bkgd: bool. If True, assume a white background.
+ raw_noise_std: ...
+ verbose: bool. If True, print more debugging info.
+ Returns:
+ rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
+ disp_map: [num_rays]. Disparity map. 1 / depth.
+ acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
+ raw: [num_rays, num_samples, 4]. Raw predictions from model.
+ rgb0: See rgb_map. Output for coarse model.
+ disp0: See disp_map. Output for coarse model.
+ acc0: See acc_map. Output for coarse model.
+ z_std: [num_rays]. Standard deviation of distances along ray for each
+ sample.
+ """
+ N_rays = ray_batch.shape[0]
+ rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each
+ viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
+ bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
+ near, far = bounds[..., 0], bounds[..., 1] # [-1,1]
+
+ t_vals = torch.linspace(0., 1., steps=N_samples, device=rays_o.device)
+ if not lindisp:
+ z_vals = near * (1.-t_vals) + far * (t_vals)
+ else:
+ z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
+
+ z_vals = z_vals.expand([N_rays, N_samples])
+
+ if perturb > 0.:
+ # get intervals between samples
+ mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
+ upper = torch.cat([mids, z_vals[..., -1:]], -1)
+ lower = torch.cat([z_vals[..., :1], mids], -1)
+ # stratified samples in those intervals
+ t_rand = torch.rand(z_vals.shape, device=rays_o.device)
+
+ # Pytest, overwrite u with numpy's fixed random numbers
+ if pytest:
+ np.random.seed(0)
+ t_rand = np.random.rand(*list(z_vals.shape))
+ t_rand = torch.Tensor(t_rand).to(rays_o.device)
+ t_rand[..., -1] = 1.0
+ z_vals = lower + (upper - lower) * t_rand
+ pts = rays_o[..., None, :] + rays_d[..., None, :] * \
+ z_vals[..., :, None] # [N_rays, N_samples, 3]
+
+
+# raw = run_network(pts)
+ raw = network_query_fn(pts, viewdirs, aud_para, network_fn)
+ rgb_map, disp_map, acc_map, weights, depth_map, rgb_map_fg = raw2outputs(
+ raw, z_vals, rays_d, bc_rgb, raw_noise_std, white_bkgd, pytest=pytest)
+
+ if N_importance > 0:
+
+ rgb_map_0, disp_map_0, acc_map_0, last_weight_0, rgb_map_fg_0 = \
+ rgb_map, disp_map, acc_map, weights[..., -1], rgb_map_fg
+
+ z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
+ z_samples = sample_pdf(
+ z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.), pytest=pytest)
+ z_samples = z_samples.detach()
+
+ z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
+ pts = rays_o[..., None, :] + rays_d[..., None, :] * \
+ z_vals[..., :, None]
+
+ run_fn = network_fn if network_fine is None else network_fine
+ raw = network_query_fn(pts, viewdirs, aud_para, run_fn)
+
+ rgb_map, disp_map, acc_map, weights, depth_map, rgb_map_fg = raw2outputs(
+ raw, z_vals, rays_d, bc_rgb, raw_noise_std, white_bkgd, pytest=pytest)
+
+ ret = {'rgb_map': rgb_map, 'disp_map': disp_map,
+ 'acc_map': acc_map, 'rgb_map_fg': rgb_map_fg}
+ if retraw:
+ ret['raw'] = raw
+ if N_importance > 0:
+ ret['rgb0'] = rgb_map_0
+ ret['disp0'] = disp_map_0
+ ret['acc0'] = acc_map_0
+ ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
+ ret['last_weight'] = weights[..., -1]
+ ret['last_weight0'] = last_weight_0
+ ret['rgb_map_fg0'] = rgb_map_fg_0
+
+ for k in ret:
+ if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
+ print(f"! [Numerical Error] {k} contains nan or inf.")
+
+ return ret
+
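The sampling above draws one stratified sample per depth bin: bin edges come from midpoints of the evenly spaced `z_vals`, and `t_rand` jitters within each bin. A minimal NumPy sketch of the same idea (the helper name `stratified_z_vals` is illustrative, not from this codebase):

```python
import numpy as np

def stratified_z_vals(near, far, n_samples, rng=None):
    """One uniformly jittered depth sample per bin between near and far."""
    if rng is None:
        rng = np.random.default_rng(0)
    t = np.linspace(0.0, 1.0, n_samples)
    z_vals = near * (1.0 - t) + far * t          # evenly spaced depths
    mids = 0.5 * (z_vals[1:] + z_vals[:-1])      # bin midpoints
    upper = np.concatenate([mids, z_vals[-1:]])  # per-bin upper edge
    lower = np.concatenate([z_vals[:1], mids])   # per-bin lower edge
    t_rand = rng.random(n_samples)               # jitter in [0, 1)
    return lower + (upper - lower) * t_rand

z = stratified_z_vals(0.3, 0.9, 64)
assert z.shape == (64,) and np.all(np.diff(z) > 0)
```

Because each jittered sample stays inside its own non-overlapping interval, the result is already sorted without an explicit sort.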
+
+def config_parser():
+
+ import configargparse
+ parser = configargparse.ArgumentParser()
+ parser.add_argument('--config', is_config_file=True,
+ help='config file path')
+ parser.add_argument("--expname", type=str,
+ help='experiment name')
+ parser.add_argument("--basedir", type=str, default='./logs/',
+ help='where to store ckpts and logs')
+ parser.add_argument("--datadir", type=str, default='./data/llff/fern',
+ help='input data directory')
+
+ # training options
+ parser.add_argument("--netdepth", type=int, default=8,
+ help='layers in network')
+ parser.add_argument("--netwidth", type=int, default=256,
+ help='channels per layer')
+ parser.add_argument("--netdepth_fine", type=int, default=8,
+ help='layers in fine network')
+ parser.add_argument("--netwidth_fine", type=int, default=256,
+ help='channels per layer in fine network')
+ parser.add_argument("--N_rand", type=int, default=1024,
+ help='batch size (number of random rays per gradient step)')
+ parser.add_argument("--lrate", type=float, default=5e-4,
+ help='learning rate')
+ parser.add_argument("--lrate_decay", type=int, default=250,
+ help='exponential learning rate decay (in units of 1000 steps)')
+ parser.add_argument("--chunk", type=int, default=1024,
+ help='number of rays processed in parallel, decrease if running out of memory')
+ parser.add_argument("--netchunk", type=int, default=1024*64,
+ help='number of pts sent through network in parallel, decrease if running out of memory')
+ parser.add_argument("--no_batching", action='store_false',
+ help='only take random rays from 1 image at a time')
+ parser.add_argument("--no_reload", action='store_true',
+ help='do not reload weights from saved ckpt')
+ parser.add_argument("--ft_path", type=str, default=None,
+ help='specific weights npy file to reload for coarse network')
+ parser.add_argument("--N_iters", type=int, default=400000,
+ help='number of iterations')
+
+ # rendering options
+ parser.add_argument("--N_samples", type=int, default=64,
+ help='number of coarse samples per ray')
+ parser.add_argument("--N_importance", type=int, default=128,
+ help='number of additional fine samples per ray')
+ parser.add_argument("--perturb", type=float, default=1.,
+ help='set to 0. for no jitter, 1. for jitter')
+ parser.add_argument("--use_viewdirs", action='store_false',
+ help='use full 5D input instead of 3D')
+ parser.add_argument("--i_embed", type=int, default=0,
+ help='set 0 for default positional encoding, -1 for none')
+ parser.add_argument("--multires", type=int, default=10,
+ help='log2 of max freq for positional encoding (3D location)')
+ parser.add_argument("--multires_views", type=int, default=4,
+ help='log2 of max freq for positional encoding (2D direction)')
+ parser.add_argument("--raw_noise_std", type=float, default=0.,
+ help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
+
+ parser.add_argument("--render_only", action='store_true',
+ help='do not optimize, reload weights and render out render_poses path')
+ parser.add_argument("--render_test", action='store_true',
+ help='render the test set instead of render_poses path')
+ parser.add_argument("--render_factor", type=int, default=0,
+ help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
+
+ # training options
+ parser.add_argument("--precrop_iters", type=int, default=0,
+ help='number of steps to train on central crops')
+ parser.add_argument("--precrop_frac", type=float,
+ default=.5, help='fraction of img taken for central crops')
+
+ # dataset options
+ parser.add_argument("--dataset_type", type=str, default='audface',
+ help='options: audface / llff / blender / deepvoxels')
+ parser.add_argument("--testskip", type=int, default=1,
+ help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
+
+ # deepvoxels flags
+ parser.add_argument("--shape", type=str, default='greek',
+ help='options : armchair / cube / greek / vase')
+
+ # blender flags
+ parser.add_argument("--white_bkgd", action='store_false',
+ help='set to render synthetic data on a white bkgd (always use for dvoxels)')
+ parser.add_argument("--half_res", action='store_true',
+ help='load blender synthetic data at 400x400 instead of 800x800')
+
+ # face flags
+ parser.add_argument("--with_test", type=int, default=0,
+ help='whether to test')
+ parser.add_argument("--dim_aud", type=int, default=64,
+ help='dimension of audio features for NeRF')
+ parser.add_argument("--dim_aud_body", type=int, default=64,
+ help='dimension of audio features for the torso NeRF')
+ parser.add_argument("--sample_rate", type=float, default=0.95,
+ help="sample rate in a bounding box")
+ parser.add_argument("--near", type=float, default=0.3,
+ help="near sampling plane")
+ parser.add_argument("--far", type=float, default=0.9,
+ help="far sampling plane")
+ parser.add_argument("--test_pose_file", type=str, default='transforms_train.json',
+ help='test pose file')
+ parser.add_argument("--aud_file", type=str, default='aud.npy',
+ help='test audio deepspeech file')
+ parser.add_argument("--win_size", type=int, default=16,
+ help="windows size of audio feature")
+ parser.add_argument("--smo_size", type=int, default=8,
+ help="window size for smoothing audio features")
+ parser.add_argument('--test_size', type=int, default=-1,
+ help='test size')
+ parser.add_argument('--aud_start', type=int, default=0,
+ help='test audio start pos')
+ parser.add_argument('--test_save_folder', type=str, default='test_aud_rst',
+ help='folder to store test result')
+
+ # llff flags
+ parser.add_argument("--factor", type=int, default=8,
+ help='downsample factor for LLFF images')
+ parser.add_argument("--no_ndc", action='store_true',
+ help='do not use normalized device coordinates (set for non-forward facing scenes)')
+ parser.add_argument("--lindisp", action='store_true',
+ help='sampling linearly in disparity rather than depth')
+ parser.add_argument("--spherify", action='store_true',
+ help='set for spherical 360 scenes')
+ parser.add_argument("--llffhold", type=int, default=8,
+ help='will take every 1/N images as LLFF test set, paper uses 8')
+
+ # logging/saving options
+ parser.add_argument("--i_print", type=int, default=100,
+ help='frequency of console printout and metric logging')
+ parser.add_argument("--i_img", type=int, default=500,
+ help='frequency of tensorboard image logging')
+ parser.add_argument("--i_weights", type=int, default=10000,
+ help='frequency of weight ckpt saving')
+ parser.add_argument("--i_testset", type=int, default=10000,
+ help='frequency of testset saving')
+ parser.add_argument("--i_video", type=int, default=50000,
+ help='frequency of render_poses video saving')
+
+ return parser
+
+
+def train():
+
+ parser = config_parser()
+ args = parser.parse_args()
+
+ # Load data
+ if args.with_test == 1:
+ poses, auds, bc_img, hwfcxy, aud_ids, torso_pose = \
+ load_test_data(args.datadir, args.aud_file,
+ args.test_pose_file, args.testskip, args.test_size, args.aud_start)
+ torso_pose = torch.as_tensor(torso_pose).to(device_torso).float()
+ com_images = np.zeros(1)
+ else:
+ com_images, poses, auds, bc_img, hwfcxy, sample_rects, \
+ i_split = load_audface_data(args.datadir, args.testskip)
+
+ if args.with_test == 0:
+ i_train, i_val = i_split
+
+ near = args.near
+ far = args.far
+
+ # Cast intrinsics to right types
+ H, W, focal, cx, cy = hwfcxy
+ H, W = int(H), int(W)
+ hwf = [H, W, focal]
+ hwfcxy = [H, W, focal, cx, cy]
+
+ # Create log dir and copy the config file
+ basedir = args.basedir
+ expname = args.expname
+ os.makedirs(os.path.join(basedir, expname), exist_ok=True)
+ f = os.path.join(basedir, expname, 'args.txt')
+ with open(f, 'w') as file:
+ for arg in sorted(vars(args)):
+ attr = getattr(args, arg)
+ file.write('{} = {}\n'.format(arg, attr))
+ if args.config is not None:
+ f = os.path.join(basedir, expname, 'config.txt')
+ with open(f, 'w') as file:
+ file.write(open(args.config, 'r').read())
+
+ # Create nerf model
+ render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, \
+ learned_codes, AudNet_state, optimizer_aud_state, AudAttNet_state = create_nerf(
+ args, 'head.tar', args.dim_aud, device, True)
+ global_step = start
+
+ AudNet = AudioNet(args.dim_aud, args.win_size).to(device)
+ AudAttNet = AudioAttNet().to(device)
+ optimizer_Aud = torch.optim.Adam(
+ params=list(AudNet.parameters()), lr=args.lrate, betas=(0.9, 0.999))
+
+ if AudNet_state is not None:
+ AudNet.load_state_dict(AudNet_state)
+ if AudAttNet_state is not None:
+ print('load audattnet')
+ AudAttNet.load_state_dict(AudAttNet_state)
+ if optimizer_aud_state is not None:
+ optimizer_Aud.load_state_dict(optimizer_aud_state)
+ bds_dict = {
+ 'near': near,
+ 'far': far,
+ }
+ render_kwargs_train.update(bds_dict)
+ render_kwargs_test.update(bds_dict)
+
+ # Move training data to GPU
+ bc_img = torch.Tensor(bc_img).to(device).float()/255.0
+ poses = torch.Tensor(poses).to(device).float()
+ auds = torch.Tensor(auds).to(device).float()
+
+ num_frames = com_images.shape[0]
+
+ embed_fn, input_ch = get_embedder(3, 0)
+ dim_torso_signal = args.dim_aud_body + 2*input_ch
+ # Create torso nerf model
+ render_kwargs_train_torso, render_kwargs_test_torso, start, grad_vars_torso, optimizer_torso, \
+ learned_codes_torso, AudNet_state_torso, optimizer_aud_state_torso, _ = create_nerf(
+ args, 'body.tar', dim_torso_signal, device_torso)
+ global_step = start
+
+ AudNet_torso = AudioNet(args.dim_aud_body, args.win_size).to(device_torso)
+ optimizer_Aud_torso = torch.optim.Adam(
+ params=list(AudNet_torso.parameters()), lr=args.lrate, betas=(0.9, 0.999))
+
+ if AudNet_state_torso is not None:
+ AudNet_torso.load_state_dict(AudNet_state_torso)
+ if optimizer_aud_state_torso is not None:
+ optimizer_Aud_torso.load_state_dict(optimizer_aud_state_torso)
+ bds_dict = {
+ 'near': near,
+ 'far': far,
+ }
+ render_kwargs_train_torso.update(bds_dict)
+ render_kwargs_test_torso.update(bds_dict)
+
+ if args.with_test:
+ print('RENDER ONLY')
+ with torch.no_grad():
+ testsavedir = os.path.join(basedir, expname, args.test_save_folder)
+ os.makedirs(testsavedir, exist_ok=True)
+ print('test poses shape', poses.shape)
+ smo_half_win = int(args.smo_size / 2)
+ auds_val = []
+ for i in range(poses.shape[0]):
+ left_i = i - smo_half_win
+ right_i = i + smo_half_win
+ pad_left, pad_right = 0, 0
+ if left_i < 0:
+ pad_left = -left_i
+ left_i = 0
+ if right_i > poses.shape[0]:
+ pad_right = right_i - poses.shape[0]
+ right_i = poses.shape[0]
+ auds_win = auds[left_i:right_i]
+ if pad_left > 0:
+ auds_win = torch.cat(
+ (torch.zeros_like(auds_win)[:pad_left], auds_win), dim=0)
+ if pad_right > 0:
+ auds_win = torch.cat(
+ (auds_win, torch.zeros_like(auds_win)[:pad_right]), dim=0)
+ auds_win = AudNet(auds_win)
+ aud_smo = AudAttNet(auds_win)
+ auds_val.append(aud_smo)
+ auds_val = torch.stack(auds_val, 0)
+
+ adjust_poses = poses.clone()
+ adjust_poses_torso = poses.clone()
+
+ et = pose_to_euler_trans(adjust_poses_torso)
+ embed_et = torch.cat(
+ (embed_fn(et[:, :3]), embed_fn(et[:, 3:])), dim=-1).to(device_torso)
+ signal = torch.cat((auds_val[..., :args.dim_aud_body].to(
+ device_torso), embed_et.squeeze()), dim=-1)
+ t_start = time.time()
+ vid_out = cv2.VideoWriter(os.path.join(testsavedir, 'result.avi'),
+ cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 25, (W, H))
+ for j in range(poses.shape[0]):
+ rgbs, disps, last_weights, rgb_fgs = \
+ render_path(adjust_poses[j:j+1], auds_val[j:j+1],
+ bc_img, hwfcxy, args.chunk, render_kwargs_test)
+ rgbs_torso, disps_torso, last_weights_torso, rgb_fgs_torso = \
+ render_path(torso_pose.unsqueeze(
+ 0), signal[j:j+1], bc_img.to(device_torso), hwfcxy, args.chunk, render_kwargs_test_torso)
+ rgbs_com = rgbs*last_weights_torso[..., None] + rgb_fgs_torso
+ rgb8 = to8b(rgbs_com[0])
+ vid_out.write(rgb8[:, :, ::-1])
+ filename = os.path.join(
+ testsavedir, str(aud_ids[j]) + '.jpg')
+ imageio.imwrite(filename, rgb8)
+ print('finished render', j)
+ print('finished render in', time.time()-t_start)
+ vid_out.release()
+ return
+
+ N_rand = args.N_rand
+ use_batching = not args.no_batching
+ if use_batching:
+ # For random ray batching
+ print('get rays')
+ rays = np.stack([get_rays_np(H, W, focal, p, cx, cy)
+ for p in poses[:, :3, :4]], 0) # [N, ro+rd, H, W, 3]
+ print('done, concats')
+ # [N, ro+rd+rgb, H, W, 3]
+ rays_rgb = np.concatenate([rays, com_images[:, None]], 1)
+ # [N, H, W, ro+rd+rgb, 3]
+ rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
+ rays_rgb = np.stack([rays_rgb[i]
+ for i in i_train], 0) # train images only
+ # [(N-1)*H*W, ro+rd+rgb, 3]
+ rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
+ rays_rgb = rays_rgb.astype(np.float32)
+ print('shuffle rays')
+ np.random.shuffle(rays_rgb)
+
+ print('done')
+ i_batch = 0
+
+ if use_batching:
+ rays_rgb = torch.Tensor(rays_rgb).to(device)
+
+ N_iters = args.N_iters + 1
+ print('Begin')
+ print('TRAIN views are', i_train)
+ print('VAL views are', i_val)
+
+ start = start + 1
+ for i in trange(start, N_iters):
+ time0 = time.time()
+
+ # Sample random ray batch
+ if use_batching:
+ # Random over all images
+ batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
+ batch = torch.transpose(batch, 0, 1)
+ batch_rays, target_s = batch[:2], batch[2]
+
+ i_batch += N_rand
+ if i_batch >= rays_rgb.shape[0]:
+ print("Shuffle data after an epoch!")
+ rand_idx = torch.randperm(rays_rgb.shape[0])
+ rays_rgb = rays_rgb[rand_idx]
+ i_batch = 0
+
+ else:
+ # Random from one image
+ img_i = np.random.choice(i_train)
+ target_com = torch.as_tensor(imageio.imread(
+ com_images[img_i])).to(device).float()/255.0
+ pose = poses[img_i, :3, :4]
+ pose_torso = poses[0, :3, :4].to(device_torso)
+ rect = sample_rects[img_i]
+ aud = auds[img_i]
+
+ smo_half_win = int(args.smo_size/2)
+ left_i = img_i - smo_half_win
+ right_i = img_i + smo_half_win
+ pad_left, pad_right = 0, 0
+ if left_i < 0:
+ pad_left = -left_i
+ left_i = 0
+ if right_i > i_train.shape[0]:
+ pad_right = right_i-i_train.shape[0]
+ right_i = i_train.shape[0]
+ auds_win = auds[left_i:right_i]
+ if pad_left > 0:
+ auds_win = torch.cat(
+ (torch.zeros_like(auds_win)[:pad_left], auds_win), dim=0)
+ if pad_right > 0:
+ auds_win = torch.cat(
+ (auds_win, torch.zeros_like(auds_win)[:pad_right]), dim=0)
+ auds_win = AudNet(auds_win)
+ aud_smo = AudAttNet(auds_win)
+ aud_smo_torso = aud_smo.to(device_torso)[..., :args.dim_aud_body]
+
+ et = pose_to_euler_trans(poses[img_i].unsqueeze(0))
+ embed_et = torch.cat(
+ (embed_fn(et[:, :3]), embed_fn(et[:, 3:])), dim=1).to(device_torso)
+ signal = torch.cat((aud_smo_torso, embed_et.squeeze()), dim=-1)
+ if N_rand is not None:
+ rays_o, rays_d = get_rays(
+ H, W, focal, pose, cx, cy, device) # (H, W, 3), (H, W, 3)
+ rays_o_torso, rays_d_torso = get_rays(
+ H, W, focal, pose_torso, cx, cy, device_torso)
+
+ if i < args.precrop_iters:
+ dH = int(H//2 * args.precrop_frac)
+ dW = int(W//2 * args.precrop_frac)
+ coords = torch.stack(
+ torch.meshgrid(
+ torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
+ torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
+ ), -1)
+ if i == start:
+ print(
+ f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
+ else:
+ coords = torch.stack(torch.meshgrid(torch.linspace(
+ 0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
+
+ coords = torch.reshape(coords, [-1, 2]) # (H * W, 2)
+ if args.sample_rate > 0:
+ rect = [0, H/2, W, H/2]
+ rect_inds = (coords[:, 0] >= rect[0]) & (
+ coords[:, 0] <= rect[0] + rect[2]) & (
+ coords[:, 1] >= rect[1]) & (
+ coords[:, 1] <= rect[1] + rect[3])
+ coords_rect = coords[rect_inds]
+ coords_norect = coords[~rect_inds]
+ rect_num = int(N_rand*float(rect[2])*rect[3]/H/W)
+ norect_num = N_rand - rect_num
+ select_inds_rect = np.random.choice(
+ coords_rect.shape[0], size=[rect_num], replace=False) # (N_rand,)
+ # (N_rand, 2)
+ select_coords_rect = coords_rect[select_inds_rect].long()
+ select_inds_norect = np.random.choice(
+ coords_norect.shape[0], size=[norect_num], replace=False) # (N_rand,)
+ # (N_rand, 2)
+ select_coords_norect = coords_norect[select_inds_norect].long(
+ )
+ select_coords = torch.cat(
+ (select_coords_norect, select_coords_rect), dim=0)
+
+ else:
+ select_inds = np.random.choice(
+ coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
+ select_coords = coords[select_inds].long()
+ norect_num = 0
+
+ rays_o = rays_o[select_coords[:, 0],
+ select_coords[:, 1]] # (N_rand, 3)
+ rays_d = rays_d[select_coords[:, 0],
+ select_coords[:, 1]] # (N_rand, 3)
+ batch_rays = torch.stack([rays_o, rays_d], 0)
+ bc_rgb = bc_img[select_coords[:, 0],
+ select_coords[:, 1]]
+
+ rays_o_torso = rays_o_torso[select_coords[:, 0],
+ select_coords[:, 1]] # (N_rand, 3)
+ rays_d_torso = rays_d_torso[select_coords[:, 0],
+ select_coords[:, 1]] # (N_rand, 3)
+ batch_rays_torso = torch.stack([rays_o_torso, rays_d_torso], 0)
+ bc_rgb = bc_img[select_coords[:, 0],
+ select_coords[:, 1]]
+ bc_rgb_torso = bc_rgb.to(device_torso)
+
+ target_s_com = target_com[select_coords[:, 0],
+ select_coords[:, 1]] # (N_rand, 3)
+
+ ##### Core optimization loop #####
+ rgb, disp, acc, last_weight, rgb_fg, extras = \
+ render_dynamic_face(H, W, focal, cx, cy, chunk=args.chunk, rays=batch_rays,
+ aud_para=aud_smo, bc_rgb=bc_rgb,
+ verbose=i < 10, retraw=True,
+ ** render_kwargs_train)
+ rgb_torso, disp_torso, acc_torso, last_weight_torso, rgb_fg_torso, extras_torso = \
+ render_dynamic_face(H, W, focal, cx, cy, chunk=args.chunk, rays=batch_rays_torso,
+ aud_para=signal, bc_rgb=bc_rgb_torso,
+ verbose=i < 10, retraw=True,
+ **render_kwargs_train_torso)
+ rgb_com = rgb * \
+ last_weight_torso.to(device)[..., None] + rgb_fg_torso.to(device)
+
+ optimizer_torso.zero_grad()
+ img_loss_com = img2mse(rgb_com, target_s_com)
+ trans = extras['raw'][..., -1]
+ split_weight = float(1.0)
+ loss = img_loss_com
+ psnr = mse2psnr(img_loss_com)
+
+ if 'rgb0' in extras_torso:
+ rgb_com0 = extras['rgb0'] * \
+ extras_torso['last_weight0'].to(
+ device)[..., None] + extras_torso['rgb_map_fg0'].to(device)
+ img_loss0 = img2mse(rgb_com0, target_s_com)
+ loss = loss + img_loss0
+
+ loss.backward()
+ optimizer_torso.step()
+
+ # NOTE: IMPORTANT!
+ ### update learning rate ###
+ decay_rate = 0.1
+ decay_steps = args.lrate_decay * 1000
+ new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
+ #print('cur_rate', new_lrate)
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = new_lrate
+
+ for param_group in optimizer_Aud.param_groups:
+ param_group['lr'] = new_lrate
+
+ for param_group in optimizer_torso.param_groups:
+ param_group['lr'] = new_lrate
+
+ for param_group in optimizer_Aud_torso.param_groups:
+ param_group['lr'] = new_lrate
+ ################################
+
+ dt = time.time()-time0
+
+ # Rest is logging
+ if i % args.i_weights == 0:
+ path = os.path.join(basedir, expname, '{:06d}_head.tar'.format(i))
+ torch.save({
+ 'global_step': global_step,
+ 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
+ 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
+ 'network_audnet_state_dict': AudNet.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'optimizer_aud_state_dict': optimizer_Aud.state_dict(),
+ 'network_audattnet_state_dict': AudAttNet.state_dict(),
+ }, path)
+
+ path = os.path.join(basedir, expname, '{:06d}_body.tar'.format(i))
+ torch.save({
+ 'global_step': global_step,
+ 'network_fn_state_dict': render_kwargs_train_torso['network_fn'].state_dict(),
+ 'network_fine_state_dict': render_kwargs_train_torso['network_fine'].state_dict(),
+ 'network_audnet_state_dict': AudNet_torso.state_dict(),
+ 'optimizer_state_dict': optimizer_torso.state_dict(),
+ 'optimizer_aud_state_dict': optimizer_Aud_torso.state_dict(),
+ }, path)
+ print('Saved checkpoints at', path)
+
+ if i % args.i_testset == 0 and i > 0:
+ testsavedir = os.path.join(
+ basedir, expname, 'testset_{:06d}'.format(i))
+ os.makedirs(testsavedir, exist_ok=True)
+ print('test poses shape', poses[i_val].shape)
+
+ aud_torso = AudNet(
+ auds[i_val])[..., :args.dim_aud_body].to(device_torso)
+ et = pose_to_euler_trans(poses[i_val])
+ embed_et = torch.cat(
+ (embed_fn(et[:, :3]), embed_fn(et[:, 3:])), dim=1).to(device_torso)
+ signal = torch.cat((aud_torso, embed_et.squeeze()), dim=-1)
+
+ auds_val = AudNet(auds[i_val])
+ with torch.no_grad():
+ for j in range(auds_val.shape[0]):
+ rgbs, disps, last_weights, rgb_fgs = \
+ render_path(poses[i_val][j:j+1], auds_val[j:j+1],
+ bc_img, hwfcxy, args.chunk, render_kwargs_test)
+ rgbs_torso, disps_torso, last_weights_torso, rgb_fgs_torso = \
+ render_path(poses[0].to(device_torso).unsqueeze(0),
+ signal[j:j+1], bc_img.to(
+ device_torso), hwfcxy, args.chunk, render_kwargs_test_torso)
+ rgbs_com = rgbs * \
+ last_weights_torso[..., None] + rgb_fgs_torso
+ rgb8 = to8b(rgbs_com[0])
+ filename = os.path.join(
+ testsavedir, '{:03d}.jpg'.format(j))
+ imageio.imwrite(filename, rgb8)
+ print('Saved test set')
+
+ if i % args.i_print == 0:
+ tqdm.write(
+ f"[TRAIN] Iter: {i} Loss: {img_loss_com.item()} PSNR: {psnr.item()}")
+
+ global_step += 1
+
+
+if __name__ == '__main__':
+ torch.set_default_tensor_type('torch.cuda.FloatTensor')
+
+ train()
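The training loop above decays the learning rate exponentially, `lrate * 0.1 ** (step / (lrate_decay * 1000))`, and applies the same rate to all four optimizers. A small sketch of that schedule (the helper name `decayed_lr` is illustrative):

```python
def decayed_lr(base_lr, step, lrate_decay=250, decay_rate=0.1):
    """Exponential decay: drops by `decay_rate` every lrate_decay*1000 steps."""
    return base_lr * decay_rate ** (step / (lrate_decay * 1000))

assert decayed_lr(5e-4, 0) == 5e-4                      # no decay at step 0
assert abs(decayed_lr(5e-4, 250_000) - 5e-5) < 1e-12    # 10x smaller after 250k steps
```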
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf_helpers.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf_helpers.py
new file mode 100644
index 00000000..e47025fb
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf_helpers.py
@@ -0,0 +1,454 @@
+import numpy as np
+import torch.nn.functional as F
+import torch.nn as nn
+import torch
+# NOTE: anomaly detection aids debugging but noticeably slows training
+torch.autograd.set_detect_anomaly(True)
+
+
+
+# Misc
+def img2mse(x, y, num=0):
+ if num > 0:
+ return torch.mean((x[num] - y[num]) ** 2)
+ else:
+ return torch.mean((x - y) ** 2)
+
+
+def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
+
+
+def to8b(x): return (255*np.clip(x, 0, 1)).astype(np.uint8)
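`img2mse` and `mse2psnr` above implement the usual PSNR relation for images normalized to [0, 1]; in plain Python (the helper name `mse_to_psnr` is illustrative):

```python
import math

def mse_to_psnr(mse):
    """PSNR in dB for a [0, 1] image: -10 * log10(MSE)."""
    return -10.0 * math.log10(mse)

assert abs(mse_to_psnr(1e-2) - 20.0) < 1e-9   # MSE 1e-2 -> 20 dB
assert abs(mse_to_psnr(1e-4) - 40.0) < 1e-9   # MSE 1e-4 -> 40 dB
```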
+
+
+# Positional encoding (section 5.1)
+class Embedder:
+ def __init__(self, **kwargs):
+ self.kwargs = kwargs
+ self.create_embedding_fn()
+
+ def create_embedding_fn(self):
+ embed_fns = []
+ d = self.kwargs['input_dims']
+ device = self.kwargs['device']
+ out_dim = 0
+ if self.kwargs['include_input']:
+ embed_fns.append(lambda x: x)
+ out_dim += d
+
+ max_freq = self.kwargs['max_freq_log2']
+ N_freqs = self.kwargs['num_freqs']
+
+ if self.kwargs['log_sampling']:
+ freq_bands = 2.**torch.linspace(0.,
+ max_freq, steps=N_freqs, device=device)
+ else:
+ freq_bands = torch.linspace(
+ 2.**0., 2.**max_freq, steps=N_freqs, device=device)
+
+ for freq in freq_bands:
+ for p_fn in self.kwargs['periodic_fns']:
+ embed_fns.append(lambda x, p_fn=p_fn,
+ freq=freq: p_fn(x * freq))
+ out_dim += d
+
+ self.embed_fns = embed_fns
+ self.out_dim = out_dim
+
+ def embed(self, inputs):
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
+
+
+def get_embedder(multires, i=0, device=torch.device('cuda', 0)):
+ if i == -1:
+ return nn.Identity(), 3
+
+ embed_kwargs = {
+ 'include_input': True,
+ 'input_dims': 3,
+ 'max_freq_log2': multires-1,
+ 'num_freqs': multires,
+ 'log_sampling': True,
+ 'device': device,
+ 'periodic_fns': [torch.sin, torch.cos],
+ }
+
+ embedder_obj = Embedder(**embed_kwargs)
+ def embed(x, eo=embedder_obj): return eo.embed(x)
+ return embed, embedder_obj.out_dim
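With the defaults above (`include_input=True`, log-sampled bands 2^0..2^(multires-1), sin/cos pairs), a 3-D point maps to 3 + 2 * multires * 3 channels, i.e. 63 for multires=10. A NumPy sketch of the encoding (the function name `positional_encode` is illustrative):

```python
import numpy as np

def positional_encode(x, num_freqs=10):
    """Keep the input and append sin/cos at log-spaced frequencies."""
    freqs = 2.0 ** np.arange(num_freqs)   # 2^0 .. 2^(num_freqs-1)
    parts = [x]
    for f in freqs:
        parts += [np.sin(x * f), np.cos(x * f)]
    return np.concatenate(parts, axis=-1)

enc = positional_encode(np.zeros(3))
assert enc.shape == (63,)                 # 3 + 2 * 10 * 3
```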
+
+
+# Audio feature extractor
+class AudioAttNet(nn.Module):
+ def __init__(self, dim_aud=32, seq_len=8):
+ super(AudioAttNet, self).__init__()
+ self.seq_len = seq_len
+ self.dim_aud = dim_aud
+ self.attentionConvNet = nn.Sequential( # b x subspace_dim x seq_len
+ nn.Conv1d(self.dim_aud, 16, kernel_size=3,
+ stride=1, padding=1, bias=True),
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True),
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True),
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True),
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True),
+ nn.LeakyReLU(0.02, True)
+ )
+ self.attentionNet = nn.Sequential(
+ nn.Linear(in_features=self.seq_len,
+ out_features=self.seq_len, bias=True),
+ nn.Softmax(dim=1)
+ )
+
+ def forward(self, x):
+ y = x[..., :self.dim_aud].permute(1, 0).unsqueeze(
+ 0) # 1 x dim_aud x seq_len
+ y = self.attentionConvNet(y)
+ y = self.attentionNet(y.view(1, self.seq_len)).view(self.seq_len, 1)
+ #print(y.view(-1).data)
+ return torch.sum(y*x, dim=0)
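`AudioAttNet` learns softmax attention weights over a window of `seq_len` audio features and returns their weighted sum. The computation, minus the learned conv scorer, sketched in NumPy (the helper name `attention_smooth` is illustrative):

```python
import numpy as np

def attention_smooth(feats, scores):
    """Weighted sum of [seq_len, dim] features with softmax(scores) weights."""
    w = np.exp(scores - scores.max())
    w = w / w.sum()                        # softmax over the window
    return (w[:, None] * feats).sum(axis=0)

out = attention_smooth(np.ones((8, 32)), np.zeros(8))
assert out.shape == (32,) and np.allclose(out, 1.0)
```

Uniform scores give uniform weights, so the output is simply the mean feature over the window; the learned scorer biases this toward the most informative frames.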
+
+
+# Audio feature extractor
+class AudioNet(nn.Module):
+ def __init__(self, dim_aud=76, win_size=16):
+ super(AudioNet, self).__init__()
+ self.win_size = win_size
+ self.dim_aud = dim_aud
+ self.encoder_conv = nn.Sequential( # n x 29 x 16
+ nn.Conv1d(29, 32, kernel_size=3, stride=2,
+ padding=1, bias=True), # n x 32 x 8
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(32, 32, kernel_size=3, stride=2,
+ padding=1, bias=True), # n x 32 x 4
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(32, 64, kernel_size=3, stride=2,
+ padding=1, bias=True), # n x 64 x 2
+ nn.LeakyReLU(0.02, True),
+ nn.Conv1d(64, 64, kernel_size=3, stride=2,
+ padding=1, bias=True), # n x 64 x 1
+ nn.LeakyReLU(0.02, True),
+ )
+ self.encoder_fc1 = nn.Sequential(
+ nn.Linear(64, 64),
+ nn.LeakyReLU(0.02, True),
+ nn.Linear(64, dim_aud),
+ )
+
+ def forward(self, x):
+ half_w = int(self.win_size/2)
+ x = x[:, 8-half_w:8+half_w, :].permute(0, 2, 1)
+ x = self.encoder_conv(x).squeeze(-1)
+ x = self.encoder_fc1(x).squeeze()
+ return x
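`encoder_conv` above halves the temporal axis at each of its four stride-2, kernel-3, padding-1 convolutions, collapsing the 16-frame window to a single feature vector. The standard Conv1d length formula confirms this (the helper name `conv1d_out_len` is illustrative):

```python
def conv1d_out_len(length, kernel=3, stride=2, padding=1):
    """Output length of a 1-D convolution: floor((L + 2p - k) / s) + 1."""
    return (length + 2 * padding - kernel) // stride + 1

length = 16
for _ in range(4):
    length = conv1d_out_len(length)        # 16 -> 8 -> 4 -> 2 -> 1
assert length == 1
```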
+# Model
+
+
+class FaceNeRF(nn.Module):
+ def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, dim_aud=32,
+ output_ch=4, skips=[4], use_viewdirs=False):
+ """Audio-conditioned NeRF MLP: audio features are concatenated with the
+ embedded 3D position before the first layer.
+ """
+ super(FaceNeRF, self).__init__()
+ self.D = D
+ self.W = W
+ self.input_ch = input_ch
+ self.input_ch_views = input_ch_views
+ self.dim_aud = dim_aud
+ self.skips = skips
+ self.use_viewdirs = use_viewdirs
+
+ input_ch_all = input_ch + dim_aud
+ self.pts_linears = nn.ModuleList(
+ [nn.Linear(input_ch_all, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch_all, W) for i in range(D-1)])
+
+ # Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
+ # self.views_linears = nn.ModuleList(
+ # [nn.Linear(input_ch_views + W, W//2)])
+
+ # Implementation according to the paper
+ self.views_linears = nn.ModuleList(
+ [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//4)])
+
+ if use_viewdirs:
+ self.feature_linear = nn.Linear(W, W)
+ self.alpha_linear = nn.Linear(W, 1)
+ self.rgb_linear = nn.Linear(W//2, 3)
+ else:
+ self.output_linear = nn.Linear(W, output_ch)
+
+ def forward(self, x):
+ input_pts, input_views = torch.split(
+ x, [self.input_ch+self.dim_aud, self.input_ch_views], dim=-1)
+ h = input_pts
+ for i, l in enumerate(self.pts_linears):
+ h = self.pts_linears[i](h)
+ h = F.relu(h)
+ if i in self.skips:
+ h = torch.cat([input_pts, h], -1)
+
+ if self.use_viewdirs:
+ alpha = self.alpha_linear(h)
+ feature = h # self.feature_linear(h)
+ h = torch.cat([feature, input_views], -1)
+
+ for i, l in enumerate(self.views_linears):
+ h = self.views_linears[i](h)
+ h = F.relu(h)
+
+ rgb = self.rgb_linear(h)
+ outputs = torch.cat([rgb, alpha], -1)
+ else:
+ outputs = self.output_linear(h)
+
+ return outputs
+
+ def load_weights_from_keras(self, weights):
+ assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
+
+ # Load pts_linears
+ for i in range(self.D):
+ idx_pts_linears = 2 * i
+ self.pts_linears[i].weight.data = torch.from_numpy(
+ np.transpose(weights[idx_pts_linears]))
+ self.pts_linears[i].bias.data = torch.from_numpy(
+ np.transpose(weights[idx_pts_linears+1]))
+
+ # Load feature_linear
+ idx_feature_linear = 2 * self.D
+ self.feature_linear.weight.data = torch.from_numpy(
+ np.transpose(weights[idx_feature_linear]))
+ self.feature_linear.bias.data = torch.from_numpy(
+ np.transpose(weights[idx_feature_linear+1]))
+
+ # Load views_linears
+ idx_views_linears = 2 * self.D + 2
+ self.views_linears[0].weight.data = torch.from_numpy(
+ np.transpose(weights[idx_views_linears]))
+ self.views_linears[0].bias.data = torch.from_numpy(
+ np.transpose(weights[idx_views_linears+1]))
+
+ # Load rgb_linear
+ idx_rgb_linear = 2 * self.D + 4
+ self.rgb_linear.weight.data = torch.from_numpy(
+ np.transpose(weights[idx_rgb_linear]))
+ self.rgb_linear.bias.data = torch.from_numpy(
+ np.transpose(weights[idx_rgb_linear+1]))
+
+ # Load alpha_linear
+ idx_alpha_linear = 2 * self.D + 6
+ self.alpha_linear.weight.data = torch.from_numpy(
+ np.transpose(weights[idx_alpha_linear]))
+ self.alpha_linear.bias.data = torch.from_numpy(
+ np.transpose(weights[idx_alpha_linear+1]))
+
+
+# Model
+class NeRF(nn.Module):
+ def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
+ """Standard NeRF MLP (embedded positions plus optional view directions)."""
+ super(NeRF, self).__init__()
+ self.D = D
+ self.W = W
+ self.input_ch = input_ch
+ self.input_ch_views = input_ch_views
+ self.skips = skips
+ self.use_viewdirs = use_viewdirs
+
+ self.pts_linears = nn.ModuleList(
+ [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
+
+ # Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
+ self.views_linears = nn.ModuleList(
+ [nn.Linear(input_ch_views + W, W//2)])
+
+ # Implementation according to the paper
+ # self.views_linears = nn.ModuleList(
+ # [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])
+
+ if use_viewdirs:
+ self.feature_linear = nn.Linear(W, W)
+ self.alpha_linear = nn.Linear(W, 1)
+ self.rgb_linear = nn.Linear(W//2, 3)
+ else:
+ self.output_linear = nn.Linear(W, output_ch)
+
+ def forward(self, x):
+ input_pts, input_views = torch.split(
+ x, [self.input_ch, self.input_ch_views], dim=-1)
+ h = input_pts
+ for i, l in enumerate(self.pts_linears):
+ h = self.pts_linears[i](h)
+ h = F.relu(h)
+ if i in self.skips:
+ h = torch.cat([input_pts, h], -1)
+
+ if self.use_viewdirs:
+ alpha = self.alpha_linear(h)
+ feature = self.feature_linear(h)
+ h = torch.cat([feature, input_views], -1)
+
+ for i, l in enumerate(self.views_linears):
+ h = self.views_linears[i](h)
+ h = F.relu(h)
+
+ rgb = self.rgb_linear(h)
+ outputs = torch.cat([rgb, alpha], -1)
+ else:
+ outputs = self.output_linear(h)
+
+ return outputs
+
+ def load_weights_from_keras(self, weights):
+ assert self.use_viewdirs, "Not implemented if use_viewdirs=False"
+
+ # Load pts_linears
+ for i in range(self.D):
+ idx_pts_linears = 2 * i
+ self.pts_linears[i].weight.data = torch.from_numpy(
+ np.transpose(weights[idx_pts_linears]))
+ self.pts_linears[i].bias.data = torch.from_numpy(
+ np.transpose(weights[idx_pts_linears+1]))
+
+ # Load feature_linear
+ idx_feature_linear = 2 * self.D
+ self.feature_linear.weight.data = torch.from_numpy(
+ np.transpose(weights[idx_feature_linear]))
+ self.feature_linear.bias.data = torch.from_numpy(
+ np.transpose(weights[idx_feature_linear+1]))
+
+ # Load views_linears
+ idx_views_linears = 2 * self.D + 2
+ self.views_linears[0].weight.data = torch.from_numpy(
+ np.transpose(weights[idx_views_linears]))
+ self.views_linears[0].bias.data = torch.from_numpy(
+ np.transpose(weights[idx_views_linears+1]))
+
+ # Load rgb_linear
+ idx_rgb_linear = 2 * self.D + 4
+ self.rgb_linear.weight.data = torch.from_numpy(
+ np.transpose(weights[idx_rgb_linear]))
+ self.rgb_linear.bias.data = torch.from_numpy(
+ np.transpose(weights[idx_rgb_linear+1]))
+
+ # Load alpha_linear
+ idx_alpha_linear = 2 * self.D + 6
+ self.alpha_linear.weight.data = torch.from_numpy(
+ np.transpose(weights[idx_alpha_linear]))
+ self.alpha_linear.bias.data = torch.from_numpy(
+ np.transpose(weights[idx_alpha_linear+1]))
+
+
+# Ray helpers
+def get_rays(H, W, focal, c2w, cx=None, cy=None, device_cur=torch.device('cuda', 0)):
+ # pytorch's meshgrid has indexing='ij'
+ i, j = torch.meshgrid(torch.linspace(0, W-1, W, device=device_cur, dtype=torch.float32),
+ torch.linspace(0, H-1, H, device=device_cur, dtype=torch.float32))
+ i = i.t()
+ j = j.t()
+ if cx is None:
+ cx = W*.5
+ if cy is None:
+ cy = H*.5
+ dirs = torch.stack(
+ [(i-cx)/focal, -(j-cy)/focal, -torch.ones_like(i)], -1)
+ # Rotate ray directions from camera frame to the world frame
+ # dot product, equals to: [c2w.dot(dir) for dir in dirs]
+ rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
+ # Translate camera frame's origin to the world frame. It is the origin of all rays.
+ rays_o = c2w[:3, -1].expand(rays_d.shape)
+ return rays_o, rays_d
+
+
+def get_rays_np(H, W, focal, c2w, cx=None, cy=None):
+ if cx is None:
+ cx = W*.5
+ if cy is None:
+ cy = H*.5
+ i, j = np.meshgrid(np.arange(W, dtype=np.float32),
+ np.arange(H, dtype=np.float32), indexing='xy')
+ dirs = np.stack([(i-cx)/focal, -(j-cy)/focal, -np.ones_like(i)], -1)
+ # Rotate ray directions from camera frame to the world frame
+ # dot product, equals to: [c2w.dot(dir) for dir in dirs]
+ rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1)
+ # Translate camera frame's origin to the world frame. It is the origin of all rays.
+ rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d))
+ return rays_o, rays_d
+
+
+def ndc_rays(H, W, focal, near, rays_o, rays_d):
+ # Shift ray origins to near plane
+ t = -(near + rays_o[..., 2]) / rays_d[..., 2]
+ rays_o = rays_o + t[..., None] * rays_d
+
+ # Projection
+ o0 = -1./(W/(2.*focal)) * rays_o[..., 0] / rays_o[..., 2]
+ o1 = -1./(H/(2.*focal)) * rays_o[..., 1] / rays_o[..., 2]
+ o2 = 1. + 2. * near / rays_o[..., 2]
+
+ d0 = -1./(W/(2.*focal)) * \
+ (rays_d[..., 0]/rays_d[..., 2] - rays_o[..., 0]/rays_o[..., 2])
+ d1 = -1./(H/(2.*focal)) * \
+ (rays_d[..., 1]/rays_d[..., 2] - rays_o[..., 1]/rays_o[..., 2])
+ d2 = -2. * near / rays_o[..., 2]
+
+ rays_o = torch.stack([o0, o1, o2], -1)
+ rays_d = torch.stack([d0, d1, d2], -1)
+
+ return rays_o, rays_d
+
+
+# Hierarchical sampling (section 5.2)
+def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
+ # Get pdf
+ weights = weights + 1e-5 # prevent nans
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
+ cdf = torch.cumsum(pdf, -1)
+ # (batch, len(bins))
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
+
+ # Take uniform samples
+ if det:
+ u = torch.linspace(0., 1., steps=N_samples, device=cdf.device)
+ u = u.expand(list(cdf.shape[:-1]) + [N_samples])
+ else:
+ u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=cdf.device)
+
+ # Pytest, overwrite u with numpy's fixed random numbers
+ if pytest:
+ np.random.seed(0)
+ new_shape = list(cdf.shape[:-1]) + [N_samples]
+ if det:
+ u = np.linspace(0., 1., N_samples)
+ u = np.broadcast_to(u, new_shape)
+ else:
+ u = np.random.rand(*new_shape)
+ u = torch.Tensor(u)
+
+ # Invert CDF
+ u = u.contiguous()
+ inds = torch.searchsorted(cdf, u, right=True)
+ below = torch.max(torch.zeros_like(inds-1), inds-1)
+ above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
+ inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
+
+ # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
+ # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
+
+ denom = (cdf_g[..., 1]-cdf_g[..., 0])
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
+ t = (u-cdf_g[..., 0])/denom
+ samples = bins_g[..., 0] + t * (bins_g[..., 1]-bins_g[..., 0])
+
+ return samples
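`sample_pdf` above implements NeRF's hierarchical sampling (paper section 5.2): build a CDF from the coarse-pass weights, then invert it so that fine samples concentrate where the density mass is. A condensed, self-contained sketch of the same inverse-CDF idea with toy bins and weights (`sample_pdf_sketch` is an illustrative name, not the toolkit's API):

```python
import torch

def sample_pdf_sketch(bins, weights, n_samples, det=True):
    # Normalize weights into a PDF over bins and accumulate a CDF
    weights = weights + 1e-5  # avoid an all-zero row / division by zero
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # prepend 0

    # Sample positions in [0, 1]: stratified (det) or uniform random
    if det:
        u = torch.linspace(0., 1., steps=n_samples)
        u = u.expand(list(cdf.shape[:-1]) + [n_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [n_samples])

    # Invert the CDF: find the bracketing CDF entries for each u
    inds = torch.searchsorted(cdf, u.contiguous(), right=True)
    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1)  # (batch, n_samples, 2)

    matched = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched), 2, inds_g)

    # Linearly interpolate inside the bracketing bin
    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
```

With `det=True` the samples are deterministic; the `pytest` branch in the toolkit's version instead overwrites `u` with seeded NumPy randoms to make tests reproducible.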
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/AD_NeRF_master.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/AD_NeRF_master.py
new file mode 100644
index 00000000..b6d59727
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/AD_NeRF_master.py
@@ -0,0 +1,114 @@
+import os
+import torch
+import numpy as np
+import imageio
+import json
+import torch.nn.functional as F
+import cv2
+
+
+def load_audface_data(basedir, testskip=1, test_file=None, aud_file=None, test_size=-1):
+ if test_file is not None:
+ with open(os.path.join(basedir, test_file)) as fp:
+ meta = json.load(fp)
+ poses = []
+ auds = []
+ aud_features = np.load(os.path.join(basedir, aud_file))
+ cur_id = 0
+ for frame in meta['frames'][::testskip]:
+ poses.append(np.array(frame['transform_matrix']))
+ aud_id = cur_id
+ auds.append(aud_features[aud_id])
+ cur_id = cur_id + 1
+ if cur_id == aud_features.shape[0] or cur_id == test_size:
+ break
+ poses = np.array(poses).astype(np.float32)
+ auds = np.array(auds).astype(np.float32)
+ bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg'))
+ H, W = bc_img.shape[0], bc_img.shape[1]
+ focal, cx, cy = float(meta['focal_len']), float(
+ meta['cx']), float(meta['cy'])
+ return poses, auds, bc_img, [H, W, focal, cx, cy]
+
+ splits = ['train', 'val']
+ metas = {}
+ for s in splits:
+ with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
+ metas[s] = json.load(fp)
+ all_com_imgs = []
+ all_poses = []
+ all_auds = []
+ all_sample_rects = []
+ aud_features = np.load(os.path.join(basedir, 'aud.npy'))
+ counts = [0]
+ for s in splits:
+ meta = metas[s]
+ com_imgs = []
+ poses = []
+ auds = []
+ sample_rects = []
+ if s == 'train' or testskip == 0:
+ skip = 1
+ else:
+ skip = testskip
+
+ for frame in meta['frames'][::skip]:
+ filename = os.path.join(basedir, 'com_imgs',
+ str(frame['img_id']) + '.jpg')
+ com_imgs.append(filename)
+ poses.append(np.array(frame['transform_matrix']))
+ auds.append(
+ aud_features[min(frame['aud_id'], aud_features.shape[0]-1)])
+ sample_rects.append(np.array(frame['face_rect'], dtype=np.int32))
+ com_imgs = np.array(com_imgs)
+ poses = np.array(poses).astype(np.float32)
+ auds = np.array(auds).astype(np.float32)
+ counts.append(counts[-1] + com_imgs.shape[0])
+ all_com_imgs.append(com_imgs)
+ all_poses.append(poses)
+ all_auds.append(auds)
+ all_sample_rects.append(sample_rects)
+ i_split = [np.arange(counts[i], counts[i+1]) for i in range(len(splits))]
+ com_imgs = np.concatenate(all_com_imgs, 0)
+ poses = np.concatenate(all_poses, 0)
+ auds = np.concatenate(all_auds, 0)
+ sample_rects = np.concatenate(all_sample_rects, 0)
+
+ bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg'))
+
+ H, W = bc_img.shape[:2]
+ focal, cx, cy = float(meta['focal_len']), float(
+ meta['cx']), float(meta['cy'])
+
+ return com_imgs, poses, auds, bc_img, [H, W, focal, cx, cy], \
+ sample_rects, i_split
+
+
+def load_test_data(basedir, aud_file, test_pose_file='transforms_train.json',
+ testskip=1, test_size=-1, aud_start=0):
+ with open(os.path.join(basedir, test_pose_file)) as fp:
+ meta = json.load(fp)
+ poses = []
+ auds = []
+ aud_features = np.load(aud_file)
+ aud_ids = []
+ cur_id = 0
+ for frame in meta['frames'][::testskip]:
+ poses.append(np.array(frame['transform_matrix']))
+ auds.append(
+ aud_features[min(aud_start+cur_id, aud_features.shape[0]-1)])
+ aud_ids.append(aud_start+cur_id)
+ cur_id = cur_id + 1
+ if cur_id == test_size or cur_id == aud_features.shape[0]:
+ break
+ poses = np.array(poses).astype(np.float32)
+ auds = np.array(auds).astype(np.float32)
+ bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg'))
+ H, W = bc_img.shape[0], bc_img.shape[1]
+ focal, cx, cy = float(meta['focal_len']), float(
+ meta['cx']), float(meta['cy'])
+
+ with open(os.path.join(basedir, 'transforms_train.json')) as fp:
+ meta_torso = json.load(fp)
+ torso_pose = np.array(meta_torso['frames'][0]['transform_matrix'])
+ return poses, auds, bc_img, [H, W, focal, cx, cy], aud_ids, torso_pose
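All three loaders above read the same `transforms_*.json` layout produced by AD-NeRF's preprocessing: global intrinsics (`focal_len`, `cx`, `cy`) plus a list of per-frame entries. A minimal illustration of that schema; the field names are taken from the code above, while the numeric values are placeholders:

```python
import json

# Minimal transforms file covering every field the loaders access
meta = {
    "focal_len": 1150.0,  # placeholder intrinsics
    "cx": 225.0,
    "cy": 225.0,
    "frames": [
        {
            "img_id": 0,   # index of com_imgs/<img_id>.jpg
            "aud_id": 0,   # row index into the aud.npy feature array
            "transform_matrix": [[1, 0, 0, 0],   # 4x4 pose matrix
                                 [0, 1, 0, 0],
                                 [0, 0, 1, 0],
                                 [0, 0, 0, 1]],
            "face_rect": [80, 60, 320, 360],  # face region used as sample_rects
        }
    ],
}
text = json.dumps(meta)
```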
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/__init__.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/__init__.py
new file mode 100644
index 00000000..bf675ce3
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/__init__.py
@@ -0,0 +1 @@
+from talkingface.model.nerf_based_talkingface.wav2lip import Wav2Lip, SyncNet_color
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/wav2lip.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/wav2lip.py
new file mode 100644
index 00000000..ff3dfb46
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/wav2lip.py
@@ -0,0 +1,358 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.nn.init import xavier_normal_, constant_
+from tqdm import tqdm
+from os import listdir, path
+import numpy as np
+import os, subprocess
+from glob import glob
+import cv2
+
+from talkingface.model.layers import Conv2d, Conv2dTranspose, nonorm_Conv2d
+from talkingface.model.abstract_talkingface import AbstractTalkingFace
+from talkingface.data.dataprocess.wav2lip_process import Wav2LipPreprocessForInference, Wav2LipAudio
+from talkingface.utils import ensure_dir
+
+class SyncNet_color(nn.Module):
+ def __init__(self):
+ super(SyncNet_color, self).__init__()
+
+ self.face_encoder = nn.Sequential(
+ Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
+
+ Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+ def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
+ face_embedding = self.face_encoder(face_sequences)
+ audio_embedding = self.audio_encoder(audio_sequences)
+
+ audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
+ face_embedding = face_embedding.view(face_embedding.size(0), -1)
+
+ audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
+ face_embedding = F.normalize(face_embedding, p=2, dim=1)
+
+ return audio_embedding, face_embedding
+
+
+class Wav2Lip(AbstractTalkingFace):
+ """Wav2Lip is a GAN-based model that predicts lip-synced talking-face frames from audio and face images."""
+ def __init__(self, config):
+ super(Wav2Lip, self).__init__()
+
+ self.face_encoder_blocks = nn.ModuleList([
+ nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
+
+ nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
+
+ nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
+
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+ self.face_decoder_blocks = nn.ModuleList([
+ nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
+
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
+
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
+
+ nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
+
+ nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
+
+ nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
+
+ nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
+
+ self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
+ nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+ nn.Sigmoid())
+ self.config = config
+ self.l1loss = nn.L1Loss()
+ self.bceloss = nn.BCELoss()
+
+ def forward(self, audio_sequences, face_sequences):
+ # audio_sequences = (B, T, 1, 80, 16)
+ B = audio_sequences.size(0)
+
+ input_dim_size = len(face_sequences.size())
+ if input_dim_size > 4:
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
+
+ feats = []
+ x = face_sequences
+ for f in self.face_encoder_blocks:
+ x = f(x)
+ feats.append(x)
+
+ x = audio_embedding
+ for f in self.face_decoder_blocks:
+ x = f(x)
+ try:
+ x = torch.cat((x, feats[-1]), dim=1)
+ except Exception as e:
+ print(x.size())
+ print(feats[-1].size())
+ raise e
+
+ feats.pop()
+
+ x = self.output_block(x)
+
+ if input_dim_size > 4:
+ x = torch.split(x, B, dim=0) # [(B, C, H, W)]
+ outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
+
+ else:
+ outputs = x
+
+ return outputs
+
+ def predict(self, audio_sequences, face_sequences):
+ return self.forward(audio_sequences, face_sequences)
+
+ def calculate_loss(self, interaction, valid=False):
+ r"""Calculate the training loss for a batch data.
+
+ Args:
+ interaction (Interaction): Interaction class of the batch.
+
+ Returns:
+ torch.Tensor: Training loss, shape: []
+ """
+ indiv_mels = interaction['indiv_mels'].to(self.config['device'])
+ input_frames = interaction['input_frames'].to(self.config['device'])
+ mel = interaction['mels'].to(self.config['device'])
+ gt = interaction['gt'].to(self.config['device'])
+ g_frames = self.forward(indiv_mels, input_frames)
+ l1loss = self.l1loss(g_frames, gt)
+ if self.config['syncnet_wt'] > 0 or valid:
+ sync_loss = self.syncnet_loss(mel, g_frames)
+ else:
+ sync_loss = 0
+
+ loss = self.config['syncnet_wt'] * sync_loss + (1 - self.config['syncnet_wt']) * l1loss
+ return {"loss": loss, "l1loss": l1loss, "sync_loss": sync_loss}
+
+ def syncnet_loss(self, mel, g_frames):
+ syncnet = self.load_syncnet()
+ syncnet.eval()
+ g = g_frames[:, :, :, g_frames.size(3)//2:]
+ g = torch.cat([g[:, :, i] for i in range(self.config['syncnet_T'])], dim=1)
+ # B, 3 * T, H//2, W
+ a, v = syncnet(mel, g)
+ y = torch.ones(g.size(0), 1).float().to(self.config['device'])
+ return self.cosine_loss(a, v, y)
+
+ def cosine_loss(self, a, v, y):
+ d = nn.functional.cosine_similarity(a, v)
+ loss = self.bceloss(d.unsqueeze(1), y)
+
+ return loss
+
+ def load_syncnet(self):
+ syncnet = SyncNet_color().to(self.config['device'])
+ for p in syncnet.parameters():
+ p.requires_grad = False
+ checkpoint = torch.load(self.config["syncnet_checkpoint_path"])
+ s = checkpoint["state_dict"]
+ new_s = {}
+ for k, v in s.items():
+ new_s[k.replace('module.', '')] = v
+ syncnet.load_state_dict(new_s)
+ return syncnet
+
+ def generate_batch(self):
+ audio_processor = Wav2LipAudio(self.config)
+ video_processor = Wav2LipPreprocessForInference(self.config)
+
+ with open(self.config['test_filelist'], 'r') as filelist:
+ lines = filelist.readlines()
+
+ file_dict = {'generated_video': [], 'real_video': []}
+ for idx, line in enumerate(tqdm(lines, desc='generate video')):
+ file_src = line.split()[0]
+
+ # the same source clip supplies both the driving audio and the reference video
+ audio_src = os.path.join(self.config['data_root'], file_src) + '.mp4'
+ video = os.path.join(self.config['data_root'], file_src) + '.mp4'
+
+ ensure_dir(os.path.join(self.config['temp_dir']))
+
+ command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, os.path.join(self.config['temp_dir'], 'temp')+'.wav')
+ subprocess.call(command, shell=True)
+
+ temp_audio = os.path.join(self.config['temp_dir'], 'temp')+'.wav'
+ wav = audio_processor.load_wav(temp_audio, 16000)
+ mel = audio_processor.melspectrogram(wav)
+
+ if np.isnan(mel.reshape(-1)).sum() > 0:
+ continue
+
+ mel_idx_multiplier = 80./self.config['fps']
+ mel_chunks = []
+ i = 0
+ while True:
+ start_idx = int(i * mel_idx_multiplier)
+ if start_idx + self.config['mel_step_size'] > len(mel[0]):
+ break
+ mel_chunks.append(mel[:, start_idx : start_idx + self.config['mel_step_size']])
+ i += 1
+
+ video_stream = cv2.VideoCapture(video)
+ full_frames = []
+ while True:
+ still_reading, frame = video_stream.read()
+ if not still_reading or len(full_frames) > len(mel_chunks):
+ video_stream.release()
+ break
+ full_frames.append(frame)
+
+ if len(full_frames) < len(mel_chunks):
+ continue
+
+ full_frames = full_frames[:len(mel_chunks)]
+
+ try:
+ face_det_results = video_processor.face_detect(full_frames.copy())
+ except ValueError: # face not detected in this clip; skip it
+ continue
+
+ batch_size = self.config['wav2lip_batch_size']
+ gen = video_processor.datagen(full_frames.copy(), face_det_results, mel_chunks)
+
+ for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
+ if i == 0:
+ frame_h, frame_w = full_frames[0].shape[:-1]
+ output_video_path = os.path.join(self.config['temp_dir'], 'temp')+'.mp4'
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v') # or try 'avc1'
+ out = cv2.VideoWriter(output_video_path, fourcc, 25, (frame_w, frame_h))
+
+
+ img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.config['device'])
+ mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.config['device'])
+
+ with torch.no_grad():
+ pred = self.predict(mel_batch, img_batch)
+
+
+ pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+ for pl, f, c in zip(pred, frames, coords):
+ y1, y2, x1, x2 = c
+ pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
+ f[y1:y2, x1:x2] = pl
+ out.write(f)
+
+ out.release()
+
+ vid = os.path.join(self.config['temp_dir'], file_src) + '.mp4'
+ vid_directory = os.path.dirname(vid)
+ if not os.path.exists(vid_directory):
+ os.makedirs(vid_directory)
+
+ command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format(temp_audio,
+ os.path.join(self.config['temp_dir'], 'temp')+'.mp4', vid)
+ process_status = subprocess.call(command, shell=True)
+ if process_status == 0:
+ file_dict['generated_video'].append(vid)
+ file_dict['real_video'].append(video)
+ else:
+ continue
+ return file_dict
+
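`syncnet_loss`/`cosine_loss` above score each generated window by feeding the cosine similarity of the audio and face embeddings into BCE against an all-ones target. A minimal sketch of that scoring step (it assumes non-negative, e.g. ReLU-activated, embeddings so the similarity stays in [0, 1], which BCE requires):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def cosine_sync_loss(audio_emb, face_emb, target):
    # Cosine similarity per pair, treated as a "probability of being in sync"
    d = F.cosine_similarity(audio_emb, face_emb)  # shape (B,)
    return nn.BCELoss()(d.unsqueeze(1), target)   # target shape (B, 1)
```

During training the target is all ones: the generator is pushed to make every generated window look in-sync to the frozen expert SyncNet.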
diff --git a/talkingface-toolkit-main/talkingface/properties/dataset/lrs2.yaml b/talkingface-toolkit-main/talkingface/properties/dataset/lrs2.yaml
new file mode 100644
index 00000000..3afa074f
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/properties/dataset/lrs2.yaml
@@ -0,0 +1,10 @@
+train_filelist: 'dataset/lrs2/filelist/train.txt' # train split file list of this dataset
+test_filelist: 'dataset/lrs2/filelist/test.txt' # test split file list of this dataset
+val_filelist: 'dataset/lrs2/filelist/val.txt' # val split file list of this dataset
+
+data_root: 'dataset/lrs2/data/main' # data root directory of this dataset
+preprocessed_root: 'dataset/lrs2/preprocessed_data' # root directory for the preprocessed data
+
+need_preprocess: True # whether the dataset needs preprocessing, e.g. frame and audio extraction
+
+preprocess_batch_size: 32
\ No newline at end of file
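The filelists referenced here are plain text with one clip per line; `generate_batch` in wav2lip.py takes the first whitespace-separated token as a path relative to `data_root` and appends `.mp4`. A small sketch of that parsing (the clip IDs below are illustrative):

```python
import os

data_root = "dataset/lrs2/data/main"
lines = ["5535415699068794046/00001\n",
         "5535415699068794046/00017\n"]

clips = []
for line in lines:
    file_src = line.split()[0]  # first token is the relative clip path
    clips.append(os.path.join(data_root, file_src) + ".mp4")
```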
diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master.yaml
new file mode 100644
index 00000000..ef3b2b78
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master.yaml
@@ -0,0 +1,37 @@
+model_params:
+ APC:
+ ckp_path: './checkpoints/live_speech_portraits/APC_epoch_160.model'
+ mel_dim: 80
+ hidden_size: 512
+ num_layers: 3
+ residual: false
+ use_LLE: 1
+ Knear: 10
+ LLE_percent: 1
+ Audio2Mouth:
+ ckp_path: './checkpoints/live_speech_portraits/McStay/checkpoints/Audio2Feature.pkl'
+ smooth: 2
+ AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z
+ Headpose:
+ ckp_path: './checkpoints/live_speech_portraits/McStay/checkpoints/Audio2Headpose.pkl'
+ sigma: 0.3
+ smooth: [5, 10] # rot, trans
+ AMP: [1, 1] # rot, trans
+ shoulder_AMP: 0.5
+ Image2Image:
+ ckp_path: './checkpoints/live_speech_portraits/McStay/checkpoints/Feature2Face.pkl'
+ size: 'normal'
+ save_input: 1
+
+# Custom additions
+checkpoint_sub_dir: "/live_speech_portraits" # joined with checkpoint_dir from overall.yaml to form the final directory
+temp_sub_dir: "/live_speech_portraits" # joined with temp_dir from overall.yaml to form the final directory
+driving_audio_path: './checkpoints/live_speech_portraits/Input/00083.wav' # path of the driving audio
+save_intermediates: 0 # whether to save intermediate files
+
+dataset_params:
+ root: './checkpoints/live_speech_portraits/McStay'
+ fit_data_path: './checkpoints/live_speech_portraits/McStay/3d_fit_data.npz'
+ pts3d_path: './checkpoints/live_speech_portraits/McStay/tracked3D_normalized_pts_fix_contour.npy'
+
+
diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/May.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/May.yaml
new file mode 100644
index 00000000..d28c4e76
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/May.yaml
@@ -0,0 +1,32 @@
+model_params:
+ APC:
+ ckp_path: './data/APC_epoch_160.model'
+ mel_dim: 80
+ hidden_size: 512
+ num_layers: 3
+ residual: false
+ use_LLE: 1
+ Knear: 10
+ LLE_percent: 1
+ Audio2Mouth:
+ ckp_path: './data/May/checkpoints/Audio2Feature.pkl'
+ smooth: 1.5
+ AMP: ['XYZ', 2, 2, 2] # method, x, y, z
+ Headpose:
+ ckp_path: './data/May/checkpoints/Audio2Headpose.pkl'
+ sigma: 0.3
+ smooth: [5, 10] # rot, trans
+ AMP: [1, 0.5] # rot, trans
+ shoulder_AMP: 0.5
+ Image2Image:
+ ckp_path: './data/May/checkpoints/Feature2Face.pkl'
+ size: 'large'
+ save_input: 1
+
+
+dataset_params:
+ root: './data/May/'
+ fit_data_path: './data/May/3d_fit_data.npz'
+ pts3d_path: './data/May/tracked3D_normalized_pts_fix_contour.npy'
+
+
diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/McStay.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/McStay.yaml
new file mode 100644
index 00000000..25d9db17
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/McStay.yaml
@@ -0,0 +1,32 @@
+model_params:
+ APC:
+ ckp_path: './data/APC_epoch_160.model'
+ mel_dim: 80
+ hidden_size: 512
+ num_layers: 3
+ residual: false
+ use_LLE: 1
+ Knear: 10
+ LLE_percent: 1
+ Audio2Mouth:
+ ckp_path: './data/McStay/checkpoints/Audio2Feature.pkl'
+ smooth: 2
+ AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z
+ Headpose:
+ ckp_path: './data/McStay/checkpoints/Audio2Headpose.pkl'
+ sigma: 0.3
+ smooth: [5, 10] # rot, trans
+ AMP: [1, 1] # rot, trans
+ shoulder_AMP: 0.5
+ Image2Image:
+ ckp_path: './data/McStay/checkpoints/Feature2Face.pkl'
+ size: 'normal'
+ save_input: 1
+
+
+dataset_params:
+ root: './data/McStay/'
+ fit_data_path: './data/McStay/3d_fit_data.npz'
+ pts3d_path: './data/McStay/tracked3D_normalized_pts_fix_contour.npy'
+
+
diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Nadella.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Nadella.yaml
new file mode 100644
index 00000000..66f33573
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Nadella.yaml
@@ -0,0 +1,32 @@
+model_params:
+ APC:
+ ckp_path: './data/APC_epoch_160.model'
+ mel_dim: 80
+ hidden_size: 512
+ num_layers: 3
+ residual: false
+ use_LLE: 1
+ Knear: 10
+ LLE_percent: 1
+ Audio2Mouth:
+ ckp_path: './data/Nadella/checkpoints/Audio2Feature.pkl'
+ smooth: 2
+ AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z
+ Headpose:
+ ckp_path: './data/Nadella/checkpoints/Audio2Headpose.pkl'
+ sigma: 0.3
+ smooth: [5, 10] # rot, trans
+ AMP: [0.5, 0.5] # rot, trans
+ shoulder_AMP: 0.5
+ Image2Image:
+ ckp_path: './data/Nadella/checkpoints/Feature2Face.pkl'
+ size: 'normal'
+ save_input: 1
+
+
+dataset_params:
+ root: './data/Nadella/'
+ fit_data_path: './data/Nadella/3d_fit_data.npz'
+ pts3d_path: './data/Nadella/tracked3D_normalized_pts_fix_contour.npy'
+
+
diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama1.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama1.yaml
new file mode 100644
index 00000000..ce414876
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama1.yaml
@@ -0,0 +1,32 @@
+model_params:
+ APC:
+ ckp_path: './data/APC_epoch_160.model'
+ mel_dim: 80
+ hidden_size: 512
+ num_layers: 3
+ residual: false
+ use_LLE: 1
+ Knear: 10
+ LLE_percent: 1
+ Audio2Mouth:
+ ckp_path: './data/Obama1/checkpoints/Audio2Feature.pkl'
+ smooth: 1
+ AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z
+ Headpose:
+ ckp_path: './data/Obama1/checkpoints/Audio2Headpose.pkl'
+ sigma: 0.3
+ smooth: [2, 8] # rot, trans
+ AMP: [1, 1] # rot, trans
+ shoulder_AMP: 0.5
+ Image2Image:
+ ckp_path: './data/Obama1/checkpoints/Feature2Face.pkl'
+ size: 'normal'
+ save_input: 1
+
+
+dataset_params:
+ root: './data/Obama1/'
+ fit_data_path: './data/Obama1/3d_fit_data.npz'
+ pts3d_path: './data/Obama1/tracked3D_normalized_pts_fix_contour.npy'
+
+
diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama2.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama2.yaml
new file mode 100644
index 00000000..6d543151
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama2.yaml
@@ -0,0 +1,32 @@
+model_params:
+ APC:
+ ckp_path: './data/APC_epoch_160.model'
+ mel_dim: 80
+ hidden_size: 512
+ num_layers: 3
+ residual: false
+ use_LLE: 1
+ Knear: 10
+ LLE_percent: 1
+ Audio2Mouth:
+ ckp_path: './data/Obama2/checkpoints/Audio2Feature.pkl'
+ smooth: 2
+ AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z
+ Headpose:
+ ckp_path: './data/Obama2/checkpoints/Audio2Headpose.pkl'
+ sigma: 0.3
+ smooth: [3, 10] # rot, trans
+ AMP: [1, 1] # rot, trans
+ shoulder_AMP: 0.5
+ Image2Image:
+ ckp_path: './data/Obama2/checkpoints/Feature2Face.pkl'
+ size: 'normal'
+ save_input: 1
+
+
+dataset_params:
+ root: './data/Obama2/'
+ fit_data_path: './data/Obama2/3d_fit_data.npz'
+ pts3d_path: './data/Obama2/tracked3D_normalized_pts_fix_contour.npy'
+
+
diff --git a/talkingface-toolkit-main/talkingface/properties/model/Wav2Lip.yaml b/talkingface-toolkit-main/talkingface/properties/model/Wav2Lip.yaml
new file mode 100644
index 00000000..50b044cc
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/properties/model/Wav2Lip.yaml
@@ -0,0 +1,55 @@
+# Syncnet
+syncnet_wt: 0.03 # (float) weight of the sync loss; starting from zero and raising it to 0.03 later leads to faster convergence
+syncnet_batch_size: 64 # (int) batch size for SyncNet training
+syncnet_lr: 0.0001 # (float) learning rate for SyncNet training
+syncnet_eval_interval: 10000
+syncnet_checkpoint_interval: 10000
+syncnet_T: 5
+syncnet_mel_step_size: 16
+syncnet_checkpoint_path: "checkpoints/wav2lip/lipsync_expert.pth"
+
+# Data preprocessing for Wav2lip
+num_mels: 80
+rescale: True
+rescaling_max: 0.9
+use_lws: False
+n_fft: 800
+hop_size: 200
+win_size: 800
+sample_rate: 16000
+frame_shift_ms: null # unused when hop_size is set explicitly
+signal_normalization: True
+allow_clipping_in_normalization: True
+symmetric_mels: True
+max_abs_value: 4
+preemphasize: True
+preemphasis: 0.97
+min_level_db: -100
+ref_level_db: 20
+fmin: 55
+fmax: 7600
+img_size: 96
+fps: 25
+mel_step_size: 16
+
+batch_size: 16
+ngpu: 1
+
+
+# Train
+checkpoint_sub_dir: "/wav2lip" # joined with checkpoint_dir from overall.yaml to form the final directory
+
+temp_sub_dir: "/wav2lip" # joined with temp_dir from overall.yaml to form the final directory
+
+
+# Inference
+pads: [0, 10, 0, 0]
+static: False
+face_det_batch_size: 16
+resize_factor: 1
+crop: [0, -1, 0, -1]
+box: [-1, -1, -1, -1]
+rotate: False
+nosmooth: False
+wav2lip_batch_size: 128
+vshift: 15
\ No newline at end of file
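With `sample_rate: 16000` and `hop_size: 200`, the mel spectrogram has 16000 / 200 = 80 frames per second, so inference advances through it at 80 / fps mel frames per video frame and cuts windows of `mel_step_size` columns. A sketch of that chunking arithmetic, mirroring the loop in `generate_batch` (`mel_chunk_starts` is an illustrative name):

```python
def mel_chunk_starts(n_mel_frames, fps=25, mel_step_size=16):
    # 80 mel frames per second of audio, stepped per video frame
    mel_idx_multiplier = 80.0 / fps
    starts = []
    i = 0
    while True:
        start = int(i * mel_idx_multiplier)
        if start + mel_step_size > n_mel_frames:
            break  # not enough mel frames left for a full window
        starts.append(start)
        i += 1
    return starts
```

Each start index yields one mel window and hence one generated video frame.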
diff --git a/talkingface-toolkit-main/talkingface/properties/overall.yaml b/talkingface-toolkit-main/talkingface/properties/overall.yaml
new file mode 100644
index 00000000..81ac51ae
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/properties/overall.yaml
@@ -0,0 +1,31 @@
+# Environment Settings
+gpu_id: '3, 4, 5' # (str) The id of GPU device(s).
+worker: 0 # (int) The number of workers processing the data.
+use_gpu: True # (bool) Whether or not to use GPU.
+seed: 2023 # (int) Random seed.
+checkpoint_dir: 'saved' # (str) The path to save checkpoint file.
+show_progress: True # (bool) Whether or not to show the progress bar of every epoch.
+log_wandb: False                # (bool) Whether or not to use Weights & Biases (W&B) logging.
+shuffle: True # (bool) Whether or not to shuffle the training data before each epoch.
+device: 'cuda'
+reproducibility: True # (bool) Whether or not to make results reproducible.
+
+# Training Settings
+epochs: 300 # (int) The number of training epochs.
+train_batch_size: 2048 # (int) The training batch size.
+learner: adam # (str) The name of used optimizer.
+learning_rate: 0.0001 # (float) Learning rate.
+eval_step: 1 # (int) The number of training epochs before an evaluation on the valid dataset.
+stopping_step: 10 # (int) The threshold for validation-based early stopping.
+weight_decay: 0.0 # (float) The weight decay value (L2 penalty) for optimizers.
+saved: True
+resume: True
+train: True
+
+# Evaluation Settings
+metrics: ["LSE", "SSIM"]
+evaluate_batch_size: 50 # (int) The evaluation batch size.
+lse_checkpoint_path: 'checkpoints/LSE/syncnet_v2.model'
+temp_dir: 'results/temp'
+lse_reference_dir: 'lse'
+valid_metric_bigger: False # (bool) Whether to take a bigger valid metric value as a better result.
\ No newline at end of file
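The `stopping_step` and `valid_metric_bigger` settings together define validation-based early stopping. A minimal sketch of that logic (the `should_stop` helper is illustrative, not the toolkit's actual implementation):

```python
def should_stop(history, patience, bigger_is_better=False):
    # Stop when the best validation value has not improved for `patience`
    # evaluations. With valid_metric_bigger: False, smaller is better.
    best_idx = max(range(len(history)),
                   key=lambda i: history[i] if bigger_is_better else -history[i])
    return (len(history) - 1 - best_idx) >= patience

# e.g. a loss-like metric whose best value was two evaluations ago
stop = should_stop([0.9, 0.8, 0.85, 0.86], patience=2)
```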
diff --git a/talkingface-toolkit-main/talkingface/quick_start/__init__.py b/talkingface-toolkit-main/talkingface/quick_start/__init__.py
new file mode 100644
index 00000000..7dd17091
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/quick_start/__init__.py
@@ -0,0 +1,5 @@
+from talkingface.quick_start.quick_start import (
+ run,
+ run_talkingface,
+)
+from talkingface.quick_start.meta_portrait_base_main import *
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/quick_start/quick_start.py b/talkingface-toolkit-main/talkingface/quick_start/quick_start.py
new file mode 100644
index 00000000..3ff2e889
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/quick_start/quick_start.py
@@ -0,0 +1,105 @@
+import os
+import sys
+from logging import getLogger
+
+from torch.utils import data as data_utils
+
+from talkingface.config import Config
+
+from talkingface.utils import (
+ init_logger,
+ get_model,
+ get_trainer,
+ init_seed,
+ set_color,
+ get_flops,
+ get_environment,
+ get_preprocess,
+ create_dataset
+)
+
+def run(
+ model,
+ dataset,
+ config_file_list=None,
+ config_dict=None,
+ saved=True,
+ evaluate_model_file=None
+):
+ res = run_talkingface(
+ model=model,
+ dataset=dataset,
+ config_file_list=config_file_list,
+ config_dict=config_dict,
+ saved=saved,
+ evaluate_model_file=evaluate_model_file,
+ )
+ return res
+
+def run_talkingface(
+ model=None,
+ dataset=None,
+ config_file_list=None,
+ config_dict=None,
+ saved=True,
+ queue=None,
+ evaluate_model_file=None
+):
+    """A fast-running API covering the complete process of training and testing a model on a specified dataset.
+    Args:
+        model (str, optional): Model name. Defaults to ``None``.
+        dataset (str, optional): Dataset name. Defaults to ``None``.
+        config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
+        config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
+        saved (bool, optional): Whether to save the model. Defaults to ``True``.
+        queue (torch.multiprocessing.Queue, optional): The queue used to pass the result to the main process. Defaults to ``None``.
+        evaluate_model_file (str, optional): Checkpoint file to evaluate when training is skipped. Defaults to ``None``.
+    """
+
+ config = Config(
+ model=model,
+ dataset=dataset,
+ config_file_list=config_file_list,
+ config_dict=config_dict,
+ )
+ init_seed(config["seed"], config["reproducibility"])
+ init_logger(config)
+ logger = getLogger()
+ logger.info(sys.argv)
+ logger.info(config)
+
+    # data preprocessing: run only when requested and when no preprocessed data exists yet
+    if config['need_preprocess'] and (not os.path.exists(config['preprocessed_root']) or not os.listdir(config['preprocessed_root'])):
+ get_preprocess(config['dataset'])(config).run()
+
+ train_dataset, val_dataset = create_dataset(config)
+ train_data_loader = data_utils.DataLoader(
+ train_dataset, batch_size=config["batch_size"], shuffle=True
+ )
+ val_data_loader = data_utils.DataLoader(
+ val_dataset, batch_size=config["batch_size"], shuffle=False
+ )
+
+ # load model
+ model = get_model(config["model"])(config).to(config["device"])
+ logger.info(model)
+
+ trainer = get_trainer(config["model"])(config, model)
+
+    # model training
+    if config['train']:
+        trainer.fit(train_data_loader, val_data_loader, saved=saved, show_progress=config["show_progress"])
+
+    if not config['train'] and evaluate_model_file is None:
+        logger.error("no model file to evaluate without training")
+        return
+
+    # model evaluating
+    trainer.evaluate(model_file=evaluate_model_file)
+
+
+
+
diff --git a/talkingface-toolkit-main/talkingface/trainer/AD_NeRF_masterTrainer.py b/talkingface-toolkit-main/talkingface/trainer/AD_NeRF_masterTrainer.py
new file mode 100644
index 00000000..51385639
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/trainer/AD_NeRF_masterTrainer.py
@@ -0,0 +1,965 @@
+from load_audface import load_audface_data
+import os
+import sys
+import numpy as np
+import imageio
+import json
+import random
+import time
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm, trange
+from natsort import natsorted
+from run_nerf_helpers import *
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+np.random.seed(0)
+DEBUG = False
+
+
+def batchify(fn, chunk):
+ """Constructs a version of 'fn' that applies to smaller batches.
+ """
+ if chunk is None:
+ return fn
+
+ def ret(inputs):
+ return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
+ return ret
+
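The chunking behavior of `batchify` can be illustrated with a self-contained NumPy analogue (NumPy stands in for torch here purely so the sketch runs anywhere):

```python
import numpy as np

def batchify_np(fn, chunk):
    # NumPy analogue of batchify above: apply fn chunk-by-chunk and
    # concatenate the results, bounding peak memory for large inputs.
    if chunk is None:
        return fn
    def ret(inputs):
        return np.concatenate(
            [fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
    return ret

# identical result to applying fn in one shot, regardless of chunk size
doubled = batchify_np(lambda x: 2 * x, chunk=3)(np.arange(10))
```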
+
+def run_network(inputs, viewdirs, aud_para, fn, embed_fn, embeddirs_fn, netchunk=1024*64):
+ """Prepares inputs and applies network 'fn'.
+ """
+ inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])
+ embedded = embed_fn(inputs_flat)
+ aud = aud_para.unsqueeze(0).expand(inputs_flat.shape[0], -1)
+ embedded = torch.cat((embedded, aud), -1)
+ if viewdirs is not None:
+ input_dirs = viewdirs[:, None].expand(inputs.shape)
+ input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])
+ embedded_dirs = embeddirs_fn(input_dirs_flat)
+ embedded = torch.cat([embedded, embedded_dirs], -1)
+
+ outputs_flat = batchify(fn, netchunk)(embedded)
+ outputs = torch.reshape(outputs_flat, list(
+ inputs.shape[:-1]) + [outputs_flat.shape[-1]])
+ return outputs
+
+
+def batchify_rays(rays_flat, bc_rgb, aud_para, chunk=1024*32, **kwargs):
+ """Render rays in smaller minibatches to avoid OOM.
+ """
+ all_ret = {}
+ for i in range(0, rays_flat.shape[0], chunk):
+ ret = render_rays(rays_flat[i:i+chunk], bc_rgb[i:i+chunk],
+ aud_para, **kwargs)
+ for k in ret:
+ if k not in all_ret:
+ all_ret[k] = []
+ all_ret[k].append(ret[k])
+
+ all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
+ return all_ret
+
+
+def render_dynamic_face(H, W, focal, cx, cy, chunk=1024*32, rays=None, bc_rgb=None, aud_para=None,
+ c2w=None, ndc=True, near=0., far=1.,
+ use_viewdirs=False, c2w_staticcam=None,
+ **kwargs):
+ if c2w is not None:
+ # special case to render full image
+ rays_o, rays_d = get_rays(H, W, focal, c2w, cx, cy)
+ bc_rgb = bc_rgb.reshape(-1, 3)
+ else:
+ # use provided ray batch
+ rays_o, rays_d = rays
+
+ if use_viewdirs:
+ # provide ray directions as input
+ viewdirs = rays_d
+ if c2w_staticcam is not None:
+ # special case to visualize effect of viewdirs
+ rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam, cx, cy)
+ viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
+ viewdirs = torch.reshape(viewdirs, [-1, 3]).float()
+
+ sh = rays_d.shape # [..., 3]
+ if ndc:
+ # for forward facing scenes
+ rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
+
+ # Create ray batch
+ rays_o = torch.reshape(rays_o, [-1, 3]).float()
+ rays_d = torch.reshape(rays_d, [-1, 3]).float()
+
+    near = near * torch.ones_like(rays_d[..., :1])
+    far = far * torch.ones_like(rays_d[..., :1])
+ rays = torch.cat([rays_o, rays_d, near, far], -1)
+ if use_viewdirs:
+ rays = torch.cat([rays, viewdirs], -1)
+
+ # Render and reshape
+ all_ret = batchify_rays(rays, bc_rgb, aud_para, chunk, **kwargs)
+ for k in all_ret:
+ k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
+ all_ret[k] = torch.reshape(all_ret[k], k_sh)
+
+ k_extract = ['rgb_map', 'disp_map', 'acc_map', 'last_weight']
+ ret_list = [all_ret[k] for k in k_extract]
+ ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
+ return ret_list + [ret_dict]
+
+
+def render(H, W, focal, cx, cy, chunk=1024*32, rays=None, c2w=None, ndc=True,
+ near=0., far=1.,
+ use_viewdirs=False, c2w_staticcam=None,
+ **kwargs):
+ """Render rays
+ Args:
+ H: int. Height of image in pixels.
+ W: int. Width of image in pixels.
+ focal: float. Focal length of pinhole camera.
+ chunk: int. Maximum number of rays to process simultaneously. Used to
+ control maximum memory usage. Does not affect final results.
+ rays: array of shape [2, batch_size, 3]. Ray origin and direction for
+ each example in batch.
+ c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
+ ndc: bool. If True, represent ray origin, direction in NDC coordinates.
+ near: float or array of shape [batch_size]. Nearest distance for a ray.
+ far: float or array of shape [batch_size]. Farthest distance for a ray.
+ use_viewdirs: bool. If True, use viewing direction of a point in space in model.
+ c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
+ camera while using other c2w argument for viewing directions.
+ Returns:
+ rgb_map: [batch_size, 3]. Predicted RGB values for rays.
+ disp_map: [batch_size]. Disparity map. Inverse of depth.
+ acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
+ extras: dict with everything returned by render_rays().
+ """
+ if c2w is not None:
+ # special case to render full image
+ rays_o, rays_d = get_rays(H, W, focal, c2w, cx, cy)
+ else:
+ # use provided ray batch
+ rays_o, rays_d = rays
+
+ if use_viewdirs:
+ # provide ray directions as input
+ viewdirs = rays_d
+ if c2w_staticcam is not None:
+ # special case to visualize effect of viewdirs
+ rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam, cx, cy)
+ viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
+ viewdirs = torch.reshape(viewdirs, [-1, 3]).float()
+
+ sh = rays_d.shape # [..., 3]
+ if ndc:
+ # for forward facing scenes
+ rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)
+
+ # Create ray batch
+ rays_o = torch.reshape(rays_o, [-1, 3]).float()
+ rays_d = torch.reshape(rays_d, [-1, 3]).float()
+
+    near = near * torch.ones_like(rays_d[..., :1])
+    far = far * torch.ones_like(rays_d[..., :1])
+ rays = torch.cat([rays_o, rays_d, near, far], -1)
+ if use_viewdirs:
+ rays = torch.cat([rays, viewdirs], -1)
+
+    # Render and reshape. NOTE: this generic path carries no background image
+    # or audio code; pass explicit placeholders so `chunk` is not silently
+    # bound to the `bc_rgb` parameter of batchify_rays.
+    all_ret = batchify_rays(rays, None, None, chunk, **kwargs)
+ for k in all_ret:
+ k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
+ all_ret[k] = torch.reshape(all_ret[k], k_sh)
+
+ k_extract = ['rgb_map', 'disp_map', 'acc_map']
+ ret_list = [all_ret[k] for k in k_extract]
+ ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
+ return ret_list + [ret_dict]
+
+
+def render_path(render_poses, aud_paras, bc_img, hwfcxy,
+ chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0):
+
+ H, W, focal, cx, cy = hwfcxy
+
+ if render_factor != 0:
+ # Render downsampled for speed
+ H = H//render_factor
+ W = W//render_factor
+ focal = focal/render_factor
+
+ rgbs = []
+ disps = []
+ last_weights = []
+
+ t = time.time()
+ for i, c2w in enumerate(tqdm(render_poses)):
+ print(i, time.time() - t)
+ t = time.time()
+ rgb, disp, acc, last_weight, _ = render_dynamic_face(
+ H, W, focal, cx, cy, chunk=chunk, c2w=c2w[:3,
+ :4], aud_para=aud_paras[i], bc_rgb=bc_img,
+ **render_kwargs)
+ rgbs.append(rgb.cpu().numpy())
+ disps.append(disp.cpu().numpy())
+ last_weights.append(last_weight.cpu().numpy())
+ if i == 0:
+ print(rgb.shape, disp.shape)
+
+ """
+ if gt_imgs is not None and render_factor==0:
+ p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i])))
+ print(p)
+ """
+
+ if savedir is not None:
+ rgb8 = to8b(rgbs[-1])
+ filename = os.path.join(savedir, '{:03d}.png'.format(i))
+ imageio.imwrite(filename, rgb8)
+
+ rgbs = np.stack(rgbs, 0)
+ disps = np.stack(disps, 0)
+ last_weights = np.stack(last_weights, 0)
+
+ return rgbs, disps, last_weights
+
+
+def create_nerf(args):
+ """Instantiate NeRF's MLP model.
+ """
+ embed_fn, input_ch = get_embedder(args.multires, args.i_embed)
+
+ input_ch_views = 0
+ embeddirs_fn = None
+ if args.use_viewdirs:
+ embeddirs_fn, input_ch_views = get_embedder(
+ args.multires_views, args.i_embed)
+ output_ch = 5 if args.N_importance > 0 else 4
+ skips = [4]
+ model = FaceNeRF(D=args.netdepth, W=args.netwidth,
+ input_ch=input_ch, dim_aud=args.dim_aud,
+ output_ch=output_ch, skips=skips,
+ input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
+ grad_vars = list(model.parameters())
+
+ model_fine = None
+ if args.N_importance > 0:
+ model_fine = FaceNeRF(D=args.netdepth_fine, W=args.netwidth_fine,
+ input_ch=input_ch, dim_aud=args.dim_aud,
+ output_ch=output_ch, skips=skips,
+ input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device)
+ grad_vars += list(model_fine.parameters())
+
+    def network_query_fn(inputs, viewdirs, aud_para, network_fn):
+        return run_network(inputs, viewdirs, aud_para, network_fn,
+                           embed_fn=embed_fn, embeddirs_fn=embeddirs_fn, netchunk=args.netchunk)
+
+ # Create optimizer
+ optimizer = torch.optim.Adam(
+ params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
+
+ start = 0
+ basedir = args.basedir
+ expname = args.expname
+
+ ##########################
+
+ # Load checkpoints
+ if args.ft_path is not None and args.ft_path != 'None':
+ ckpts = [args.ft_path]
+ else:
+ ckpts = [os.path.join(basedir, expname, f) for f in natsorted(
+ os.listdir(os.path.join(basedir, expname))) if 'tar' in f]
+
+ print('Found ckpts', ckpts)
+ learned_codes_dict = None
+ AudNet_state = None
+ AudAttNet_state = None
+ optimizer_aud_state = None
+ optimizer_audatt_state = None
+ if len(ckpts) > 0 and not args.no_reload:
+ ckpt_path = ckpts[-1]
+ print('Reloading from', ckpt_path)
+ ckpt = torch.load(ckpt_path)
+
+ start = ckpt['global_step']
+ optimizer.load_state_dict(ckpt['optimizer_state_dict'])
+ AudNet_state = ckpt['network_audnet_state_dict']
+ optimizer_aud_state = ckpt['optimizer_aud_state_dict']
+
+ # Load model
+ model.load_state_dict(ckpt['network_fn_state_dict'])
+ if model_fine is not None:
+ model_fine.load_state_dict(ckpt['network_fine_state_dict'])
+ if 'network_audattnet_state_dict' in ckpt:
+ AudAttNet_state = ckpt['network_audattnet_state_dict']
+ if 'optimize_audatt_state_dict' in ckpt:
+ optimizer_audatt_state = ckpt['optimize_audatt_state_dict']
+
+ ##########################
+
+ render_kwargs_train = {
+ 'network_query_fn': network_query_fn,
+ 'perturb': args.perturb,
+ 'N_importance': args.N_importance,
+ 'network_fine': model_fine,
+ 'N_samples': args.N_samples,
+ 'network_fn': model,
+ 'use_viewdirs': args.use_viewdirs,
+ 'white_bkgd': args.white_bkgd,
+ 'raw_noise_std': args.raw_noise_std,
+ }
+
+ # NDC only good for LLFF-style forward facing data
+ if args.dataset_type != 'llff' or args.no_ndc:
+ print('Not ndc!')
+ render_kwargs_train['ndc'] = False
+ render_kwargs_train['lindisp'] = args.lindisp
+
+ render_kwargs_test = {
+ k: render_kwargs_train[k] for k in render_kwargs_train}
+ render_kwargs_test['perturb'] = False
+ render_kwargs_test['raw_noise_std'] = 0.
+
+ return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, learned_codes_dict, \
+ AudNet_state, optimizer_aud_state, AudAttNet_state, optimizer_audatt_state
+
+
+def raw2outputs(raw, z_vals, rays_d, bc_rgb, raw_noise_std=0, white_bkgd=False, pytest=False):
+ """Transforms model's predictions to semantically meaningful values.
+ Args:
+ raw: [num_rays, num_samples along ray, 4]. Prediction from model.
+ z_vals: [num_rays, num_samples along ray]. Integration time.
+ rays_d: [num_rays, 3]. Direction of each ray.
+ Returns:
+ rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
+ disp_map: [num_rays]. Disparity map. Inverse of depth map.
+ acc_map: [num_rays]. Sum of weights along each ray.
+ weights: [num_rays, num_samples]. Weights assigned to each sampled color.
+ depth_map: [num_rays]. Estimated distance to object.
+ """
+    def raw2alpha(raw, dists, act_fn=F.relu):
+        return 1. - torch.exp(-(act_fn(raw) + 1e-6) * dists)
+
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
+ dists = torch.cat([dists, torch.Tensor([1e10]).expand(
+ dists[..., :1].shape)], -1) # [N_rays, N_samples]
+
+ dists = dists * torch.norm(rays_d[..., None, :], dim=-1)
+
+ rgb = torch.sigmoid(raw[..., :3]) # [N_rays, N_samples, 3]
+ rgb = torch.cat((rgb[:, :-1, :], bc_rgb.unsqueeze(1)), dim=1)
+ noise = 0.
+ if raw_noise_std > 0.:
+ noise = torch.randn(raw[..., 3].shape) * raw_noise_std
+
+ # Overwrite randomly sampled data if pytest
+ if pytest:
+ np.random.seed(0)
+ noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std
+ noise = torch.Tensor(noise)
+
+ alpha = raw2alpha(raw[..., 3] + noise, dists) # [N_rays, N_samples]
+    # weights[i] = alpha[i] * prod_{j<i} (1 - alpha[j])  (exclusive cumprod)
+    weights = alpha * torch.cumprod(
+        torch.cat([torch.ones((alpha.shape[0], 1)), 1. - alpha + 1e-10], -1), -1)[:, :-1]
+ rgb_map = torch.sum(weights[..., None] * rgb, -2) # [N_rays, 3]
+
+ depth_map = torch.sum(weights * z_vals, -1)
+ disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map),
+ depth_map / torch.sum(weights, -1))
+ acc_map = torch.sum(weights, -1)
+
+ if white_bkgd:
+ rgb_map = rgb_map + (1.-acc_map[..., None])
+
+ return rgb_map, disp_map, acc_map, weights, depth_map
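The weight computation in `raw2outputs` is standard alpha compositing; a toy single-ray example in NumPy (illustrative numbers, with the 1e-10 stabilizer omitted so the arithmetic is exact):

```python
import numpy as np

# Toy alpha compositing along one ray with three samples, mirroring
# raw2outputs: weight_i = alpha_i * prod_{j<i} (1 - alpha_j)
alpha = np.array([0.1, 0.5, 0.9])
transmittance = np.cumprod(np.concatenate([[1.0], 1.0 - alpha]))[:-1]
weights = alpha * transmittance
acc = weights.sum()  # accumulated opacity: approaches 1 for opaque rays
```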
+
+
+def render_rays(ray_batch,
+ bc_rgb,
+ aud_para,
+ network_fn,
+ network_query_fn,
+ N_samples,
+ retraw=False,
+ lindisp=False,
+ perturb=0.,
+ N_importance=0,
+ network_fine=None,
+ white_bkgd=False,
+ raw_noise_std=0.,
+ verbose=False,
+ pytest=False):
+ """Volumetric rendering.
+ Args:
+ ray_batch: array of shape [batch_size, ...]. All information necessary
+ for sampling along a ray, including: ray origin, ray direction, min
+ dist, max dist, and unit-magnitude viewing direction.
+ network_fn: function. Model for predicting RGB and density at each point
+ in space.
+ network_query_fn: function used for passing queries to network_fn.
+ N_samples: int. Number of different times to sample along each ray.
+ retraw: bool. If True, include model's raw, unprocessed predictions.
+ lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
+ perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
+ random points in time.
+ N_importance: int. Number of additional times to sample along each ray.
+ These samples are only passed to network_fine.
+ network_fine: "fine" network with same spec as network_fn.
+ white_bkgd: bool. If True, assume a white background.
+ raw_noise_std: ...
+ verbose: bool. If True, print more debugging info.
+ Returns:
+ rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
+ disp_map: [num_rays]. Disparity map. 1 / depth.
+ acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
+ raw: [num_rays, num_samples, 4]. Raw predictions from model.
+ rgb0: See rgb_map. Output for coarse model.
+ disp0: See disp_map. Output for coarse model.
+ acc0: See acc_map. Output for coarse model.
+ z_std: [num_rays]. Standard deviation of distances along ray for each
+ sample.
+ """
+ N_rays = ray_batch.shape[0]
+ rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each
+ viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
+ bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
+ near, far = bounds[..., 0], bounds[..., 1] # [-1,1]
+
+ t_vals = torch.linspace(0., 1., steps=N_samples)
+ if not lindisp:
+ z_vals = near * (1.-t_vals) + far * (t_vals)
+ else:
+ z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
+
+ z_vals = z_vals.expand([N_rays, N_samples])
+
+ if perturb > 0.:
+ # get intervals between samples
+ mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
+ upper = torch.cat([mids, z_vals[..., -1:]], -1)
+ lower = torch.cat([z_vals[..., :1], mids], -1)
+ # stratified samples in those intervals
+ t_rand = torch.rand(z_vals.shape)
+
+ # Pytest, overwrite u with numpy's fixed random numbers
+ if pytest:
+ np.random.seed(0)
+ t_rand = np.random.rand(*list(z_vals.shape))
+ t_rand = torch.Tensor(t_rand)
+ t_rand[..., -1] = 1.0
+ z_vals = lower + (upper - lower) * t_rand
+ pts = rays_o[..., None, :] + rays_d[..., None, :] * \
+ z_vals[..., :, None] # [N_rays, N_samples, 3]
+ raw = network_query_fn(pts, viewdirs, aud_para, network_fn)
+ rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
+ raw, z_vals, rays_d, bc_rgb, raw_noise_std, white_bkgd, pytest=pytest)
+
+ if N_importance > 0:
+
+ rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
+
+ z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
+ z_samples = sample_pdf(
+ z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.), pytest=pytest)
+ z_samples = z_samples.detach()
+
+ z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
+ pts = rays_o[..., None, :] + rays_d[..., None, :] * \
+ z_vals[..., :, None] # [N_rays, N_samples + N_importance, 3]
+
+ run_fn = network_fn if network_fine is None else network_fine
+ raw = network_query_fn(pts, viewdirs, aud_para, run_fn)
+
+ rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
+ raw, z_vals, rays_d, bc_rgb, raw_noise_std, white_bkgd, pytest=pytest)
+
+ ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
+ if retraw:
+ ret['raw'] = raw
+ if N_importance > 0:
+ ret['rgb0'] = rgb_map_0
+ ret['disp0'] = disp_map_0
+ ret['acc0'] = acc_map_0
+ ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
+ ret['last_weight'] = weights[..., -1]
+
+ for k in ret:
+ if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
+ print(f"! [Numerical Error] {k} contains nan or inf.")
+
+ return ret
+
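The `perturb > 0` branch above performs stratified sampling along each ray; a standalone NumPy sketch of the same scheme (the `stratified_z` helper is hypothetical, shown with the default near/far planes of 0.3 and 0.9):

```python
import numpy as np

def stratified_z(near, far, n_samples, rng):
    # Same scheme as the perturb > 0 branch of render_rays: split [near, far]
    # into n_samples bins around the evenly spaced depths, then draw one
    # uniform sample per bin.
    t = np.linspace(0.0, 1.0, n_samples)
    z = near * (1.0 - t) + far * t
    mids = 0.5 * (z[1:] + z[:-1])
    upper = np.concatenate([mids, z[-1:]])
    lower = np.concatenate([z[:1], mids])
    return lower + (upper - lower) * rng.random(n_samples)

# samples stay ordered and inside [near, far], one per bin
z_vals = stratified_z(0.3, 0.9, 8, np.random.default_rng(0))
```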
+
+def config_parser():
+
+ import configargparse
+ parser = configargparse.ArgumentParser()
+ parser.add_argument('--config', is_config_file=True,
+ help='config file path')
+ parser.add_argument("--expname", type=str,
+ help='experiment name')
+ parser.add_argument("--basedir", type=str, default='./logs/',
+ help='where to store ckpts and logs')
+ parser.add_argument("--datadir", type=str, default='./data/llff/fern',
+ help='input data directory')
+
+ # training options
+ parser.add_argument("--netdepth", type=int, default=8,
+ help='layers in network')
+ parser.add_argument("--netwidth", type=int, default=256,
+ help='channels per layer')
+ parser.add_argument("--netdepth_fine", type=int, default=8,
+ help='layers in fine network')
+ parser.add_argument("--netwidth_fine", type=int, default=256,
+ help='channels per layer in fine network')
+ parser.add_argument("--N_rand", type=int, default=2048,
+ help='batch size (number of random rays per gradient step)')
+ parser.add_argument("--lrate", type=float, default=5e-4,
+ help='learning rate')
+ parser.add_argument("--lrate_decay", type=int, default=250,
+ help='exponential learning rate decay (in 1000 steps)')
+ parser.add_argument("--chunk", type=int, default=1024,
+ help='number of rays processed in parallel, decrease if running out of memory')
+ parser.add_argument("--netchunk", type=int, default=1024*64,
+ help='number of pts sent through network in parallel, decrease if running out of memory')
+ parser.add_argument("--no_batching", action='store_false',
+ help='only take random rays from 1 image at a time')
+ parser.add_argument("--no_reload", action='store_true',
+ help='do not reload weights from saved ckpt')
+ parser.add_argument("--ft_path", type=str, default=None,
+ help='specific weights npy file to reload for coarse network')
+ parser.add_argument("--N_iters", type=int, default=400000,
+ help='number of iterations')
+
+ # rendering options
+ parser.add_argument("--N_samples", type=int, default=64,
+ help='number of coarse samples per ray')
+ parser.add_argument("--N_importance", type=int, default=128,
+ help='number of additional fine samples per ray')
+ parser.add_argument("--perturb", type=float, default=1.,
+ help='set to 0. for no jitter, 1. for jitter')
+ parser.add_argument("--use_viewdirs", action='store_false',
+ help='use full 5D input instead of 3D')
+ parser.add_argument("--i_embed", type=int, default=0,
+ help='set 0 for default positional encoding, -1 for none')
+ parser.add_argument("--multires", type=int, default=10,
+ help='log2 of max freq for positional encoding (3D location)')
+ parser.add_argument("--multires_views", type=int, default=4,
+ help='log2 of max freq for positional encoding (2D direction)')
+ parser.add_argument("--raw_noise_std", type=float, default=0.,
+ help='std dev of noise added to regularize sigma_a output, 1e0 recommended')
+
+ parser.add_argument("--render_only", action='store_true',
+ help='do not optimize, reload weights and render out render_poses path')
+ parser.add_argument("--render_test", action='store_true',
+ help='render the test set instead of render_poses path')
+ parser.add_argument("--render_factor", type=int, default=0,
+ help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')
+
+ # training options
+ parser.add_argument("--precrop_iters", type=int, default=0,
+ help='number of steps to train on central crops')
+ parser.add_argument("--precrop_frac", type=float,
+ default=.5, help='fraction of img taken for central crops')
+
+ # dataset options
+ parser.add_argument("--dataset_type", type=str, default='audface',
+ help='options: llff / blender / deepvoxels')
+ parser.add_argument("--testskip", type=int, default=8,
+ help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')
+
+ # deepvoxels flags
+ parser.add_argument("--shape", type=str, default='greek',
+ help='options : armchair / cube / greek / vase')
+
+ # blender flags
+ parser.add_argument("--white_bkgd", action='store_false',
+ help='set to render synthetic data on a white bkgd (always use for dvoxels)')
+ parser.add_argument("--half_res", action='store_true',
+ help='load blender synthetic data at 400x400 instead of 800x800')
+
+ # face flags
+ parser.add_argument("--with_test", type=int, default=0,
+ help='whether to use test set')
+ parser.add_argument("--dim_aud", type=int, default=64,
+ help='dimension of audio features for NeRF')
+ parser.add_argument("--sample_rate", type=float, default=0.95,
+ help="sample rate in a bounding box")
+ parser.add_argument("--near", type=float, default=0.3,
+ help="near sampling plane")
+ parser.add_argument("--far", type=float, default=0.9,
+ help="far sampling plane")
+ parser.add_argument("--test_file", type=str, default='transforms_test.json',
+ help='test file')
+ parser.add_argument("--aud_file", type=str, default='aud.npy',
+ help='test audio deepspeech file')
+ parser.add_argument("--win_size", type=int, default=16,
+ help="windows size of audio feature")
+ parser.add_argument("--smo_size", type=int, default=8,
+ help="window size for smoothing audio features")
+    parser.add_argument('--nosmo_iters', type=int, default=200000,
+                        help='number of iterations before applying smoothing on audio features')
+
+ # llff flags
+ parser.add_argument("--factor", type=int, default=8,
+ help='downsample factor for LLFF images')
+ parser.add_argument("--no_ndc", action='store_true',
+ help='do not use normalized device coordinates (set for non-forward facing scenes)')
+ parser.add_argument("--lindisp", action='store_true',
+ help='sampling linearly in disparity rather than depth')
+ parser.add_argument("--spherify", action='store_true',
+ help='set for spherical 360 scenes')
+ parser.add_argument("--llffhold", type=int, default=8,
+ help='will take every 1/N images as LLFF test set, paper uses 8')
+
+ # logging/saving options
+    parser.add_argument("--i_print", type=int, default=100,
+                        help='frequency of console printout and metric logging')
+ parser.add_argument("--i_img", type=int, default=500,
+ help='frequency of tensorboard image logging')
+ parser.add_argument("--i_weights", type=int, default=10000,
+ help='frequency of weight ckpt saving')
+ parser.add_argument("--i_testset", type=int, default=10000,
+ help='frequency of testset saving')
+ parser.add_argument("--i_video", type=int, default=50000,
+ help='frequency of render_poses video saving')
+
+ return parser
+
+
+def train():
+
+ parser = config_parser()
+ args = parser.parse_args()
+
+ # Load data
+
+ if args.dataset_type == 'audface':
+ if args.with_test == 1:
+ poses, auds, bc_img, hwfcxy = \
+ load_audface_data(args.datadir, args.testskip,
+ args.test_file, args.aud_file)
+ images = np.zeros(1)
+ else:
+ images, poses, auds, bc_img, hwfcxy, sample_rects, mouth_rects, i_split = load_audface_data(
+ args.datadir, args.testskip)
+ print('Loaded audface', images.shape, hwfcxy, args.datadir)
+ if args.with_test == 0:
+ i_train, i_val = i_split
+
+ near = args.near
+ far = args.far
+ else:
+ print('Unknown dataset type', args.dataset_type, 'exiting')
+ return
+
+ # Cast intrinsics to right types
+ H, W, focal, cx, cy = hwfcxy
+ H, W = int(H), int(W)
+ hwf = [H, W, focal]
+ hwfcxy = [H, W, focal, cx, cy]
+
+ # if args.render_test:
+ # render_poses = np.array(poses[i_test])
+
+ # Create log dir and copy the config file
+ basedir = args.basedir
+ expname = args.expname
+ os.makedirs(os.path.join(basedir, expname), exist_ok=True)
+ f = os.path.join(basedir, expname, 'args.txt')
+ with open(f, 'w') as file:
+ for arg in sorted(vars(args)):
+ attr = getattr(args, arg)
+ file.write('{} = {}\n'.format(arg, attr))
+ if args.config is not None:
+ f = os.path.join(basedir, expname, 'config.txt')
+ with open(f, 'w') as file:
+ file.write(open(args.config, 'r').read())
+
+ # Create nerf model
+ render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, \
+ learned_codes, AudNet_state, optimizer_aud_state, AudAttNet_state, optimizer_audatt_state \
+ = create_nerf(args)
+ global_step = start
+
+ AudNet = AudioNet(args.dim_aud, args.win_size).to(device)
+ AudAttNet = AudioAttNet().to(device)
+ optimizer_Aud = torch.optim.Adam(
+ params=list(AudNet.parameters()), lr=args.lrate, betas=(0.9, 0.999))
+ optimizer_AudAtt = torch.optim.Adam(
+ params=list(AudAttNet.parameters()), lr=args.lrate, betas=(0.9, 0.999))
+
+ if AudNet_state is not None:
+ AudNet.load_state_dict(AudNet_state, strict=False)
+ if optimizer_aud_state is not None:
+ optimizer_Aud.load_state_dict(optimizer_aud_state)
+ if AudAttNet_state is not None:
+ AudAttNet.load_state_dict(AudAttNet_state, strict=False)
+ if optimizer_audatt_state is not None:
+ optimizer_AudAtt.load_state_dict(optimizer_audatt_state)
+ bds_dict = {
+ 'near': near,
+ 'far': far,
+ }
+ render_kwargs_train.update(bds_dict)
+ render_kwargs_test.update(bds_dict)
+
+ # Move training data to GPU
+ bc_img = torch.Tensor(bc_img).to(device).float()/255.0
+ poses = torch.Tensor(poses).to(device).float()
+ auds = torch.Tensor(auds).to(device).float()
+
+ if args.render_only:
+ print('RENDER ONLY')
+ with torch.no_grad():
+ # Default is smoother render_poses path
+ images = None
+ testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format(
+ 'test' if args.render_test else 'path', start))
+ os.makedirs(testsavedir, exist_ok=True)
+ print('test poses shape', poses.shape)
+ auds_val = AudNet(auds)
+ rgbs, disp, last_weight = render_path(poses, auds_val, bc_img, hwfcxy, args.chunk, render_kwargs_test,
+ gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
+ np.save(os.path.join(testsavedir, 'last_weight.npy'), last_weight)
+ print('Done rendering', testsavedir)
+ imageio.mimwrite(os.path.join(
+ testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
+ return
+
+ num_frames = images.shape[0]
+
+
+ # Prepare raybatch tensor if batching random rays
+ N_rand = args.N_rand
+ print('N_rand', N_rand, 'no_batching',
+ args.no_batching, 'sample_rate', args.sample_rate)
+ use_batching = not args.no_batching
+
+ if use_batching:
+ # For random ray batching
+ print('get rays')
+ rays = np.stack([get_rays_np(H, W, focal, p, cx, cy)
+ for p in poses[:, :3, :4]], 0) # [N, ro+rd, H, W, 3]
+ print('done, concats')
+ # [N, ro+rd+rgb, H, W, 3]
+ rays_rgb = np.concatenate([rays, images[:, None]], 1)
+ # [N, H, W, ro+rd+rgb, 3]
+ rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
+ rays_rgb = np.stack([rays_rgb[i]
+ for i in i_train], 0) # train images only
+ # [(N-1)*H*W, ro+rd+rgb, 3]
+ rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
+ rays_rgb = rays_rgb.astype(np.float32)
+ print('shuffle rays')
+ np.random.shuffle(rays_rgb)
+
+ print('done')
+ i_batch = 0
+
+ if use_batching:
+ rays_rgb = torch.Tensor(rays_rgb).to(device)
+
+ N_iters = args.N_iters + 1
+ print('Begin')
+ print('TRAIN views are', i_train)
+ print('VAL views are', i_val)
+
+ start = start + 1
+ for i in trange(start, N_iters):
+ time0 = time.time()
+
+ # Sample random ray batch
+ if use_batching:
+ # Random over all images
+ batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
+ batch = torch.transpose(batch, 0, 1)
+ batch_rays, target_s = batch[:2], batch[2]
+
+ i_batch += N_rand
+ if i_batch >= rays_rgb.shape[0]:
+ print("Shuffle data after an epoch!")
+ rand_idx = torch.randperm(rays_rgb.shape[0])
+ rays_rgb = rays_rgb[rand_idx]
+ i_batch = 0
+
+ else:
+ # Random from one image
+ img_i = np.random.choice(i_train)
+ target = torch.as_tensor(imageio.imread(
+ images[img_i])).to(device).float()/255.0
+ pose = poses[img_i, :3, :4]
+ rect = sample_rects[img_i]
+ mouth_rect = mouth_rects[img_i]
+ aud = auds[img_i]
+ if global_step >= args.nosmo_iters:
+ smo_half_win = int(args.smo_size / 2)
+ left_i = img_i - smo_half_win
+ right_i = img_i + smo_half_win
+ pad_left, pad_right = 0, 0
+ if left_i < 0:
+ pad_left = -left_i
+ left_i = 0
+ if right_i > i_train.shape[0]:
+ pad_right = right_i-i_train.shape[0]
+ right_i = i_train.shape[0]
+ auds_win = auds[left_i:right_i]
+ if pad_left > 0:
+ auds_win = torch.cat(
+ (torch.zeros_like(auds_win)[:pad_left], auds_win), dim=0)
+ if pad_right > 0:
+ auds_win = torch.cat(
+ (auds_win, torch.zeros_like(auds_win)[:pad_right]), dim=0)
+ auds_win = AudNet(auds_win)
+ aud = auds_win[smo_half_win]
+ aud_smo = AudAttNet(auds_win)
+ else:
+ aud = AudNet(aud.unsqueeze(0))
+ if N_rand is not None:
+ rays_o, rays_d = get_rays(
+ H, W, focal, torch.Tensor(pose), cx, cy) # (H, W, 3), (H, W, 3)
+
+ if i < args.precrop_iters:
+ dH = int(H//2 * args.precrop_frac)
+ dW = int(W//2 * args.precrop_frac)
+ coords = torch.stack(
+ torch.meshgrid(
+ torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
+ torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
+ ), -1)
+ if i == start:
+ print(
+ f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
+ else:
+ coords = torch.stack(torch.meshgrid(torch.linspace(
+ 0, H-1, H), torch.linspace(0, W-1, W)), -1) # (H, W, 2)
+
+ coords = torch.reshape(coords, [-1, 2]) # (H * W, 2)
+ if args.sample_rate > 0:
+ rect_inds = (coords[:, 0] >= rect[0]) & (
+ coords[:, 0] <= rect[0] + rect[2]) & (
+ coords[:, 1] >= rect[1]) & (
+ coords[:, 1] <= rect[1] + rect[3])
+ coords_rect = coords[rect_inds]
+ coords_norect = coords[~rect_inds]
+ rect_num = int(N_rand*args.sample_rate)
+ norect_num = N_rand - rect_num
+ select_inds_rect = np.random.choice(
+ coords_rect.shape[0], size=[rect_num], replace=False) # (N_rand,)
+ # (N_rand, 2)
+ select_coords_rect = coords_rect[select_inds_rect].long()
+ select_inds_norect = np.random.choice(
+ coords_norect.shape[0], size=[norect_num], replace=False) # (N_rand,)
+ # (N_rand, 2)
+ select_coords_norect = coords_norect[select_inds_norect].long(
+ )
+ select_coords = torch.cat(
+ (select_coords_rect, select_coords_norect), dim=0)
+ else:
+ select_inds = np.random.choice(
+ coords.shape[0], size=[N_rand], replace=False) # (N_rand,)
+ select_coords = coords[select_inds].long()
+
+ rays_o = rays_o[select_coords[:, 0],
+ select_coords[:, 1]] # (N_rand, 3)
+ rays_d = rays_d[select_coords[:, 0],
+ select_coords[:, 1]] # (N_rand, 3)
+ batch_rays = torch.stack([rays_o, rays_d], 0)
+ target_s = target[select_coords[:, 0],
+ select_coords[:, 1]] # (N_rand, 3)
+ bc_rgb = bc_img[select_coords[:, 0],
+ select_coords[:, 1]]
+
+ ##### Core optimization loop #####
+ if global_step >= args.nosmo_iters:
+ rgb, disp, acc, _, extras = render_dynamic_face(H, W, focal, cx, cy, chunk=args.chunk, rays=batch_rays,
+ aud_para=aud_smo, bc_rgb=bc_rgb,
+ verbose=i < 10, retraw=True,
+ **render_kwargs_train)
+ else:
+ rgb, disp, acc, _, extras = render_dynamic_face(H, W, focal, cx, cy, chunk=args.chunk, rays=batch_rays,
+ aud_para=aud, bc_rgb=bc_rgb,
+ verbose=i < 10, retraw=True,
+ **render_kwargs_train)
+
+ optimizer.zero_grad()
+ optimizer_Aud.zero_grad()
+ optimizer_AudAtt.zero_grad()
+ img_loss = img2mse(rgb, target_s)
+ trans = extras['raw'][..., -1]
+ loss = img_loss
+ psnr = mse2psnr(img_loss)
+
+ if 'rgb0' in extras:
+ img_loss0 = img2mse(extras['rgb0'], target_s)
+ loss = loss + img_loss0
+ psnr0 = mse2psnr(img_loss0)
+
+ loss.backward()
+ optimizer.step()
+ optimizer_Aud.step()
+ if global_step >= args.nosmo_iters:
+ optimizer_AudAtt.step()
+ # NOTE: IMPORTANT!
+ ### update learning rate ###
+ decay_rate = 0.1
+ decay_steps = args.lrate_decay * 1000
+ new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = new_lrate
+
+ for param_group in optimizer_Aud.param_groups:
+ param_group['lr'] = new_lrate
+
+ for param_group in optimizer_AudAtt.param_groups:
+ param_group['lr'] = new_lrate*5
+ ################################
+
+ dt = time.time()-time0
+
+ # Rest is logging
+ if i % args.i_weights == 0:
+ path = os.path.join(basedir, expname, '{:06d}_head.tar'.format(i))
+ torch.save({
+ 'global_step': global_step,
+ 'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
+ 'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
+ 'network_audnet_state_dict': AudNet.state_dict(),
+ 'optimizer_state_dict': optimizer.state_dict(),
+ 'optimizer_aud_state_dict': optimizer_Aud.state_dict(),
+ 'network_audattnet_state_dict': AudAttNet.state_dict(),
+ 'optimizer_audatt_state_dict': optimizer_AudAtt.state_dict(),
+ }, path)
+ print('Saved checkpoints at', path)
+
+ if i % args.i_testset == 0 and i > 0:
+ testsavedir = os.path.join(
+ basedir, expname, 'testset_{:06d}'.format(i))
+ os.makedirs(testsavedir, exist_ok=True)
+ print('test poses shape', poses[i_val].shape)
+ auds_val = AudNet(auds[i_val])
+ with torch.no_grad():
+ render_path(torch.Tensor(poses[i_val]).to(
+ device), auds_val, bc_img, hwfcxy, args.chunk, render_kwargs_test, gt_imgs=None, savedir=testsavedir)
+ print('Saved test set')
+
+ if i % args.i_print == 0:
+ tqdm.write(
+ f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
+ global_step += 1
+
+
+if __name__ == '__main__':
+ torch.set_default_tensor_type('torch.cuda.FloatTensor')
+
+ train()
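The training loop above decays every optimizer's learning rate exponentially: `lr = lrate * 0.1 ** (step / (lrate_decay * 1000))`. A standalone sketch of that schedule follows; the base rate `5e-4` and `lrate_decay=250` are illustrative values, not taken from this repo's configs:

```python
def decayed_lr(base_lr, step, lrate_decay, decay_rate=0.1):
    """NeRF-style exponential decay: multiply by decay_rate every lrate_decay*1000 steps."""
    decay_steps = lrate_decay * 1000
    return base_lr * (decay_rate ** (step / decay_steps))

# at step 0 the schedule returns the base rate; after one full period it has decayed by 10x
print(decayed_lr(5e-4, 0, 250))        # 0.0005
print(decayed_lr(5e-4, 250_000, 250))  # ~5e-05
```

Note that the loop applies five times this rate to `optimizer_AudAtt`'s param groups.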
diff --git a/talkingface-toolkit-main/talkingface/trainer/__init__.py b/talkingface-toolkit-main/talkingface/trainer/__init__.py
new file mode 100644
index 00000000..219a0873
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/trainer/__init__.py
@@ -0,0 +1,2 @@
+from talkingface.trainer.trainer import *
+from talkingface.trainer.meta_portrait_base_train_ddp import *
diff --git a/talkingface-toolkit-main/talkingface/trainer/trainer.py b/talkingface-toolkit-main/talkingface/trainer/trainer.py
new file mode 100644
index 00000000..2c34717b
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/trainer/trainer.py
@@ -0,0 +1,557 @@
+import os
+
+from logging import getLogger
+from time import time
+import dlib, json, subprocess
+import torch.nn.functional as F
+import glob
+import numpy as np
+import torch
+import torch.optim as optim
+from torch.nn.utils.clip_grad import clip_grad_norm_
+from tqdm import tqdm
+import torch.cuda.amp as amp
+from torch import nn
+from pathlib import Path
+
+from talkingface.utils import(
+ ensure_dir,
+ get_local_time,
+ early_stopping,
+ calculate_valid_score,
+ dict2str,
+ get_tensorboard,
+ set_color,
+ get_gpu_usage,
+ WandbLogger
+)
+from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio
+from talkingface.evaluator import Evaluator
+
+
+class AbstractTrainer(object):
+ r"""Trainer Class is used to manage the training and evaluation processes of talkingface models.
+ AbstractTrainer is an abstract class in which the fit() and evaluate() methods should be implemented according
+ to different training and evaluation strategies.
+ """
+
+ def __init__(self, config, model):
+ self.config = config
+ self.model = model
+
+ def fit(self, train_data):
+ r"""Train the model based on the train data."""
+ raise NotImplementedError("Method [fit] should be implemented.")
+
+ def evaluate(self, eval_data):
+ r"""Evaluate the model based on the eval data."""
+ raise NotImplementedError("Method [evaluate] should be implemented.")
+
+
+class Trainer(AbstractTrainer):
+ r"""The basic Trainer for basic training and evaluation strategies in talkingface systems. This class defines common
+ functions for the training and evaluation processes of most talkingface models, including fit(), evaluate(),
+ resume_checkpoint() and some other features helpful for model training and evaluation.
+
+ Generally speaking, this class can serve most talkingface models, as long as the training process simply
+ optimizes a single loss without involving any complex training strategies.
+
+ Initializing the Trainer needs two parameters: `config` and `model`. `config` records the parameters information
+ for controlling training and evaluation, such as `learning_rate`, `epochs`, `eval_step` and so on.
+ `model` is the instantiated object of a Model Class.
+
+ """
+ def __init__(self, config, model):
+ super(Trainer, self).__init__(config, model)
+ self.logger = getLogger()
+ self.tensorboard = get_tensorboard(self.logger)
+ self.wandblogger = WandbLogger(config)
+ # self.enable_amp = config["enable_amp"]
+ # self.enable_scaler = torch.cuda.is_available() and config["enable_scaler"]
+
+ # config for train
+ self.learner = config["learner"]
+ self.learning_rate = config["learning_rate"]
+ self.epochs = config["epochs"]
+ self.eval_step = min(config["eval_step"], self.epochs)
+ self.stopping_step = config["stopping_step"]
+ self.test_batch_size = config["eval_batch_size"]
+ self.gpu_available = torch.cuda.is_available() and config["use_gpu"]
+ self.device = config["device"]
+ self.checkpoint_dir = config["checkpoint_dir"]
+ ensure_dir(self.checkpoint_dir)
+ saved_model_file = "{}-{}.pth".format(self.config["model"], get_local_time())
+ self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file)
+ self.weight_decay = config["weight_decay"]
+ self.start_epoch = 0
+ self.cur_step = 0
+ self.train_loss_dict = dict()
+ self.optimizer = self._build_optimizer()
+ self.evaluator = Evaluator(config)
+
+ self.valid_metric_bigger = config["valid_metric_bigger"]
+ self.best_valid_score = -np.inf if self.valid_metric_bigger else np.inf
+ self.best_valid_result = None
+
+ def _build_optimizer(self, **kwargs):
+ params = kwargs.pop("params", self.model.parameters())
+ learner = kwargs.pop("learner", self.learner)
+ learning_rate = kwargs.pop("learning_rate", self.learning_rate)
+ weight_decay = kwargs.pop("weight_decay", self.weight_decay)
+ if (self.config["reg_weight"] and weight_decay and weight_decay * self.config["reg_weight"] > 0):
+ self.logger.warning(
+ "The parameters [weight_decay] and [reg_weight] are specified simultaneously, "
+ "which may lead to double regularization."
+ )
+
+ if learner.lower() == "adam":
+ optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay)
+ elif learner.lower() == "adamw":
+ optimizer = optim.AdamW(params, lr=learning_rate, weight_decay=weight_decay)
+ elif learner.lower() == "sgd":
+ optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay)
+ elif learner.lower() == "adagrad":
+ optimizer = optim.Adagrad(
+ params, lr=learning_rate, weight_decay=weight_decay
+ )
+ elif learner.lower() == "rmsprop":
+ optimizer = optim.RMSprop(
+ params, lr=learning_rate, weight_decay=weight_decay
+ )
+ elif learner.lower() == "sparse_adam":
+ optimizer = optim.SparseAdam(params, lr=learning_rate)
+ if weight_decay > 0:
+ self.logger.warning(
+ f"SparseAdam does not support weight decay; ignoring argument [weight_decay={weight_decay}]"
+ )
+ else:
+ self.logger.warning(
+ "Received unrecognized optimizer, set default Adam optimizer"
+ )
+ optimizer = optim.Adam(params, lr=learning_rate)
+ return optimizer
+
+ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
+ r"""Train the model in an epoch
+
+ Args:
+ train_data (DataLoader): The train data.
+ epoch_idx (int): The current epoch id.
+ loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
+ :attr:`self.model.calculate_loss`. Defaults to ``None``.
+ show_progress (bool): Show the progress of training epoch. Defaults to ``False``.
+
+ Returns:
+ the averaged loss of this epoch
+ """
+ self.model.train()
+ loss_func = loss_func or self.model.calculate_loss
+ total_loss_dict = {}
+ step = 0
+ iter_data = (
+ tqdm(
+ train_data,
+ total=len(train_data),
+ ncols=None,
+ )
+ if show_progress
+ else train_data
+ )
+
+ for batch_idx, interaction in enumerate(iter_data):
+ self.optimizer.zero_grad()
+ step += 1
+ losses_dict = loss_func(interaction)
+ loss = losses_dict["loss"]
+
+ for key, value in losses_dict.items():
+ if key in total_loss_dict:
+ if not torch.is_tensor(value):
+ total_loss_dict[key] += value
+ # if the key is already in the totals dict, accumulate the current value
+ else:
+ losses_dict[key] = value.item()
+ total_loss_dict[key] += value.item()
+ else:
+ if not torch.is_tensor(value):
+ total_loss_dict[key] = value
+ # otherwise, add the current value to the dict
+ else:
+ losses_dict[key] = value.item()
+ total_loss_dict[key] = value.item()
+ if show_progress:
+ iter_data.set_description(set_color(f"train {epoch_idx} {losses_dict}", "pink"))
+
+ self._check_nan(loss)
+ loss.backward()
+ self.optimizer.step()
+ average_loss_dict = {}
+ for key, value in total_loss_dict.items():
+ average_loss_dict[key] = value/step
+
+ return average_loss_dict
+
+
+
+ def _valid_epoch(self, valid_data, show_progress=False):
+ r"""Validate the model with the valid data. Unlike evaluate(), this is used during training.
+
+ Args:
+ valid_data (DataLoader): the valid data.
+ show_progress (bool): Show the progress of evaluate epoch. Defaults to ``False``.
+
+ Returns:
+ loss
+ """
+ print('Valid for {} steps'.format(self.eval_step))
+ self.model.eval()
+ total_loss_dict = {}
+ iter_data = (
+ tqdm(valid_data,
+ total=len(valid_data),
+ ncols=None,
+ )
+ if show_progress
+ else valid_data
+ )
+ step = 0
+ for batch_idx, batched_data in enumerate(iter_data):
+ step += 1
+ batched_data.to(self.device)
+ losses_dict = self.model.calculate_loss(batched_data, valid=True)
+ for key, value in losses_dict.items():
+ if key in total_loss_dict:
+ if not torch.is_tensor(value):
+ total_loss_dict[key] += value
+ # if the key is already in the totals dict, accumulate the current value
+ else:
+ losses_dict[key] = value.item()
+ total_loss_dict[key] += value.item()
+ else:
+ if not torch.is_tensor(value):
+ total_loss_dict[key] = value
+ # otherwise, add the current value to the dict
+ else:
+ losses_dict[key] = value.item()
+ total_loss_dict[key] = value.item()
+ if show_progress:
+ iter_data.set_description(set_color(f"Valid {losses_dict}", "pink"))
+ average_loss_dict = {}
+ for key, value in total_loss_dict.items():
+ average_loss_dict[key] = value/step
+
+ return average_loss_dict
+
+
+
+
+ def _save_checkpoint(self, epoch, verbose=True, **kwargs):
+ r"""Store the model parameters information and training information.
+
+ Args:
+ epoch (int): the current epoch id
+
+ """
+ saved_model_file = kwargs.pop("saved_model_file", self.saved_model_file)
+ state = {
+ "config": self.config,
+ "epoch": epoch,
+ "cur_step": self.cur_step,
+ "best_valid_score": self.best_valid_score,
+ "state_dict": self.model.state_dict(),
+ "other_parameter": self.model.other_parameter(),
+ "optimizer": self.optimizer.state_dict(),
+ }
+ torch.save(state, saved_model_file, pickle_protocol=4)
+ if verbose:
+ self.logger.info(
+ set_color("Saving current", "blue") + f": {saved_model_file}"
+ )
+
+ def resume_checkpoint(self, resume_file):
+ r"""Load the model parameters information and training information.
+
+ Args:
+ resume_file (file): the checkpoint file
+
+ """
+ resume_file = str(resume_file)
+ self.saved_model_file = resume_file
+ checkpoint = torch.load(resume_file, map_location=self.device)
+ self.start_epoch = checkpoint["epoch"] + 1
+ self.cur_step = checkpoint["cur_step"]
+ # self.best_valid_score = checkpoint["best_valid_score"]
+
+ # load architecture params from checkpoint
+ if checkpoint["config"]["model"].lower() != self.config["model"].lower():
+ self.logger.warning(
+ "Architecture configuration given in config file is different from that of checkpoint. "
+ "This may yield an exception while state_dict is being loaded."
+ )
+ self.model.load_state_dict(checkpoint["state_dict"])
+ self.model.load_other_parameter(checkpoint.get("other_parameter"))
+
+ # load optimizer state from checkpoint only when optimizer type is not changed
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+ message_output = "Checkpoint loaded. Resume training from epoch {}".format(
+ self.start_epoch
+ )
+ self.logger.info(message_output)
+
+ def _check_nan(self, loss):
+ if torch.isnan(loss):
+ raise ValueError("Training loss is nan")
+
+ def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses):
+ des = self.config["loss_decimal_place"] or 4
+ train_loss_output = (
+ set_color(f"epoch {epoch_idx} training", "green")
+ + " ["
+ + set_color("time", "blue")
+ + f": {e_time - s_time:.2f}s, "
+ )
+ # iterate over the dict, formatting and appending each loss item
+ loss_items = [
+ set_color(f"{key}", "blue") + f": {value:.{des}f}"
+ for key, value in losses.items()
+ ]
+ # join all loss items into one string and append it to the preceding output
+ train_loss_output += ", ".join(loss_items)
+ return train_loss_output + "]"
+
+ def _add_hparam_to_tensorboard(self, best_valid_result):
+ # base hparam
+ hparam_dict = {
+ "learner": self.config["learner"],
+ "learning_rate": self.config["learning_rate"],
+ "train_batch_size": self.config["train_batch_size"],
+ }
+ # unrecorded parameter
+ unrecorded_parameter = {
+ parameter
+ for parameters in self.config.parameters.values()
+ for parameter in parameters
+ }.union({"model", "dataset", "config_files", "device"})
+ # other model-specific hparam
+ hparam_dict.update(
+ {
+ para: val
+ for para, val in self.config.final_config_dict.items()
+ if para not in unrecorded_parameter
+ }
+ )
+ for k in hparam_dict:
+ if hparam_dict[k] is not None and not isinstance(
+ hparam_dict[k], (bool, str, float, int)
+ ):
+ hparam_dict[k] = str(hparam_dict[k])
+
+ self.tensorboard.add_hparams(
+ hparam_dict, {"hparam/best_valid_result": best_valid_result}
+ )
+
+ def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None):
+ r"""Train the model based on the train data and the valid data.
+
+ Args:
+ train_data (DataLoader): the train data
+ valid_data (DataLoader, optional): the valid data, default: None.
+ If it's None, the early_stopping is invalid.
+ verbose (bool, optional): whether to write training and evaluation information to logger, default: True
+ saved (bool, optional): whether to save the model parameters, default: True
+ show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
+ callback_fn (callable): Optional callback function executed at end of epoch.
+ Includes (epoch_idx, valid_score) input arguments.
+
+ Returns:
+ best result
+ """
+ if saved and self.start_epoch >= self.epochs:
+ self._save_checkpoint(-1, verbose=verbose)
+
+ if self.config['resume_checkpoint_path'] is not None and self.config['resume']:
+ self.resume_checkpoint(self.config['resume_checkpoint_path'])
+
+ for epoch_idx in range(self.start_epoch, self.epochs):
+ training_start_time = time()
+ train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress)
+ self.train_loss_dict[epoch_idx] = (
+ sum(train_loss) if isinstance(train_loss, tuple) else train_loss
+ )
+ training_end_time = time()
+ train_loss_output = self._generate_train_loss_output(
+ epoch_idx, training_start_time, training_end_time, train_loss)
+
+ if verbose:
+ self.logger.info(train_loss_output)
+ # self._add_train_loss_to_tensorboard(epoch_idx, train_loss)
+
+ if self.eval_step <= 0 or not valid_data:
+ if saved:
+ self._save_checkpoint(epoch_idx, verbose=verbose)
+ continue
+
+ if (epoch_idx + 1) % self.eval_step == 0:
+ valid_start_time = time()
+ valid_loss = self._valid_epoch(valid_data=valid_data, show_progress=show_progress)
+
+ (self.best_valid_score, self.cur_step, stop_flag, update_flag) = early_stopping(
+ valid_loss['loss'],
+ self.best_valid_score,
+ self.cur_step,
+ max_step=self.stopping_step,
+ bigger=self.valid_metric_bigger,
+ )
+ valid_end_time = time()
+
+ valid_loss_output = (
+ set_color("valid result", "blue") + ": \n" + dict2str(valid_loss)
+ )
+ if verbose:
+ self.logger.info(valid_loss_output)
+
+
+ if update_flag:
+ if saved:
+ self._save_checkpoint(epoch_idx, verbose=verbose)
+ self.best_valid_result = valid_loss['loss']
+
+ if stop_flag:
+ stop_output = "Finished training, best eval result in epoch %d" % (
+ epoch_idx - self.cur_step * self.eval_step
+ )
+ if verbose:
+ self.logger.info(stop_output)
+ break
+
+ @torch.no_grad()
+ def evaluate(self, load_best_model=True, model_file=None):
+ """
+ Evaluate the model based on the test data.
+
+ args: load_best_model: bool, whether to load the best model in the training process.
+ model_file: str, the model file you want to evaluate.
+
+ """
+ if load_best_model:
+ checkpoint_file = model_file or self.saved_model_file
+ checkpoint = torch.load(checkpoint_file, map_location=self.device)
+ self.model.load_state_dict(checkpoint["state_dict"])
+ self.model.load_other_parameter(checkpoint.get("other_parameter"))
+ message_output = "Loading model structure and parameters from {}".format(
+ checkpoint_file
+ )
+ self.logger.info(message_output)
+ self.model.eval()
+
+ datadict = self.model.generate_batch()
+ eval_result = self.evaluator.evaluate(datadict)
+ self.logger.info(eval_result)
+
+
+
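fit() above delegates patience tracking to `early_stopping` from `talkingface.utils`, which is defined elsewhere in the toolkit. As a rough sketch of the contract it appears to follow — returning (best score, steps without improvement, stop flag, update flag) — one might write:

```python
def early_stopping(value, best, cur_step, max_step, bigger=True):
    """Hypothetical re-implementation of the (best, cur_step, stop_flag, update_flag) contract."""
    improved = (value > best) if bigger else (value < best)
    if improved:
        return value, 0, False, True          # new best: reset the patience counter
    cur_step += 1
    return best, cur_step, cur_step > max_step, False

# for a validation loss (bigger=False) an improvement resets the counter
print(early_stopping(0.5, 1.0, 3, 5, bigger=False))  # (0.5, 0, False, True)
```

The real helper's signature and return order should be checked against `talkingface/utils/utils.py`.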
+class Wav2LipTrainer(Trainer):
+ def __init__(self, config, model):
+ super(Wav2LipTrainer, self).__init__(config, model)
+
+ def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
+ r"""Train the model in an epoch
+
+ Args:
+ train_data (DataLoader): The train data.
+ epoch_idx (int): The current epoch id.
+ loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be
+ :attr:`self.model.calculate_loss`. Defaults to ``None``.
+ show_progress (bool): Show the progress of training epoch. Defaults to ``False``.
+
+ Returns:
+ the averaged loss of this epoch
+ """
+ self.model.train()
+
+
+
+ loss_func = loss_func or self.model.calculate_loss
+ total_loss_dict = {}
+ step = 0
+ iter_data = (
+ tqdm(
+ train_data,
+ total=len(train_data),
+ ncols=None,
+ )
+ if show_progress
+ else train_data
+ )
+
+ for batch_idx, interaction in enumerate(iter_data):
+ self.optimizer.zero_grad()
+ step += 1
+ losses_dict = loss_func(interaction)
+ loss = losses_dict["loss"]
+
+ for key, value in losses_dict.items():
+ if key in total_loss_dict:
+ if not torch.is_tensor(value):
+ total_loss_dict[key] += value
+ # if the key is already in the totals dict, accumulate the current value
+ else:
+ losses_dict[key] = value.item()
+ total_loss_dict[key] += value.item()
+ else:
+ if not torch.is_tensor(value):
+ total_loss_dict[key] = value
+ # otherwise, add the current value to the dict
+ else:
+ losses_dict[key] = value.item()
+ total_loss_dict[key] = value.item()
+ if show_progress:
+ iter_data.set_description(set_color(f"train {epoch_idx} {losses_dict}", "pink"))
+
+ self._check_nan(loss)
+ loss.backward()
+ self.optimizer.step()
+ average_loss_dict = {}
+ for key, value in total_loss_dict.items():
+ average_loss_dict[key] = value/step
+
+ return average_loss_dict
+
+
+
+ def _valid_epoch(self, valid_data, loss_func=None, show_progress=False):
+ print('Valid for {} steps'.format(self.eval_step))
+ self.model.eval()
+ total_loss_dict = {}
+ iter_data = (
+ tqdm(valid_data,
+ total=len(valid_data),
+ ncols=None,
+ desc=set_color("Valid", "pink")
+ )
+ if show_progress
+ else valid_data
+ )
+ step = 0
+ for batch_idx, batched_data in enumerate(iter_data):
+ step += 1
+ losses_dict = self.model.calculate_loss(batched_data, valid=True)
+ for key, value in losses_dict.items():
+ if key in total_loss_dict:
+ if not torch.is_tensor(value):
+ total_loss_dict[key] += value
+ # if the key is already in the totals dict, accumulate the current value
+ else:
+ losses_dict[key] = value.item()
+ total_loss_dict[key] += value.item()
+ else:
+ if not torch.is_tensor(value):
+ total_loss_dict[key] = value
+ # otherwise, add the current value to the dict
+ else:
+ losses_dict[key] = value.item()
+ total_loss_dict[key] = value.item()
+ average_loss_dict = {}
+ for key, value in total_loss_dict.items():
+ average_loss_dict[key] = value/step
+ if average_loss_dict["sync_loss"] < .75:
+ self.model.config["syncnet_wt"] = 0.01
+ return average_loss_dict
+
\ No newline at end of file
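Both `_train_epoch` and `_valid_epoch` accumulate per-key losses into a totals dict (calling `.item()` on tensor values, adding plain numbers directly) and divide by the step count at the end. The averaging pattern in isolation, with plain floats standing in for tensors:

```python
def accumulate_losses(batches):
    """Average per-key losses over batches; mirrors the trainer's totals dict."""
    totals, step = {}, 0
    for losses in batches:
        step += 1
        for key, value in losses.items():
            # in the trainer, tensor values go through .item() before this addition
            totals[key] = totals.get(key, 0.0) + value
    return {key: value / step for key, value in totals.items()}

avg = accumulate_losses([{"loss": 1.0, "sync": 0.5}, {"loss": 3.0, "sync": 1.5}])
print(avg)  # {'loss': 2.0, 'sync': 1.0}
```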
diff --git a/talkingface-toolkit-main/talkingface/utils/__init__.py b/talkingface-toolkit-main/talkingface/utils/__init__.py
new file mode 100644
index 00000000..712ba862
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/__init__.py
@@ -0,0 +1,43 @@
+from talkingface.utils.logger import init_logger, set_color
+from talkingface.utils.enum_type import *
+from talkingface.utils.utils import (
+ get_local_time,
+ ensure_dir,
+ get_model,
+ get_trainer,
+ get_environment,
+ early_stopping,
+ calculate_valid_score,
+ dict2str,
+ init_seed,
+ get_tensorboard,
+ get_gpu_usage,
+ get_flops,
+ list_to_latex,
+ get_preprocess,
+ create_dataset
+)
+
+from talkingface.utils.argument_list import *
+from talkingface.utils.wandblogger import WandbLogger
+__all__ = [
+ "init_logger",
+ "get_local_time",
+ "ensure_dir",
+ "get_model",
+ "get_trainer",
+ "early_stopping",
+ "calculate_valid_score",
+ "dict2str",
+ "init_seed",
+ "general_arguments",
+ "training_arguments",
+ "evaluation_arguments",
+ "get_tensorboard",
+ "set_color",
+ "get_gpu_usage",
+ "get_flops",
+ "get_environment",
+ "list_to_latex",
+ "WandbLogger",
+]
diff --git a/talkingface-toolkit-main/talkingface/utils/argument_list.py b/talkingface-toolkit-main/talkingface/utils/argument_list.py
new file mode 100644
index 00000000..01a7dad3
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/argument_list.py
@@ -0,0 +1,48 @@
+general_arguments = [
+ 'gpu_id', 'use_gpu',
+ 'seed',
+ 'reproducibility',
+ 'state',
+ 'checkpoint_dir',
+ 'show_progress',
+ 'config_file',
+ 'log_wandb',
+]
+
+training_arguments = [
+ 'epochs', 'train_batch_size',
+ 'learner', 'learning_rate',
+ 'eval_step', 'stopping_step',
+ 'weight_decay', 'resume',
+ 'train'
+]
+
+evaluation_arguments = [
+ 'metrics',
+ 'temp_dir',
+ 'evaluate_batch_size',
+ 'lse_checkpoint_path',
+ 'valid_metric_bigger',
+]
+# evaluation_arguments = [
+# 'eval_args', 'repeatable',
+# 'metrics', 'topk', 'valid_metric', 'valid_metric_bigger',
+# 'eval_batch_size',
+# 'metric_decimal_place',
+# ]
+
+# dataset_arguments = [
+# 'field_separator', 'seq_separator',
+# 'USER_ID_FIELD', 'ITEM_ID_FIELD', 'RATING_FIELD', 'TIME_FIELD',
+# 'seq_len',
+# 'LABEL_FIELD', 'threshold',
+# 'NEG_PREFIX',
+# 'ITEM_LIST_LENGTH_FIELD', 'LIST_SUFFIX', 'MAX_ITEM_LIST_LENGTH', 'POSITION_FIELD',
+# 'HEAD_ENTITY_ID_FIELD', 'TAIL_ENTITY_ID_FIELD', 'RELATION_ID_FIELD', 'ENTITY_ID_FIELD',
+# 'load_col', 'unload_col', 'unused_col', 'additional_feat_suffix',
+# 'rm_dup_inter', 'val_interval', 'filter_inter_by_user_or_item',
+# 'user_inter_num_interval', 'item_inter_num_interval',
+# 'alias_of_user_id', 'alias_of_item_id', 'alias_of_entity_id', 'alias_of_relation_id',
+# 'preload_weight', 'normalize_field', 'normalize_all',
+# 'benchmark_filename',
+# ]
diff --git a/talkingface-toolkit-main/talkingface/utils/data_process.py b/talkingface-toolkit-main/talkingface/utils/data_process.py
new file mode 100644
index 00000000..cbc430ac
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/data_process.py
@@ -0,0 +1,95 @@
+import os
+import cv2
+import numpy as np
+import subprocess
+from tqdm import tqdm
+from glob import glob
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from talkingface.utils import face_detection
+import traceback
+import librosa
+import librosa.filters
+from scipy import signal
+from scipy.io import wavfile
+
+
+class lrs2Preprocess:
+ def __init__(self, config):
+ self.config = config
+ self.fa = [face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
+ device=f'cuda:{id}') for id in range(config['ngpu'])]
+ self.template = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'
+
+ def process_video_file(self, vfile, gpu_id):
+ video_stream = cv2.VideoCapture(vfile)
+
+ frames = []
+ while True:
+ still_reading, frame = video_stream.read()
+ if not still_reading:
+ video_stream.release()
+ break
+ frames.append(frame)
+
+ vidname = os.path.basename(vfile).split('.')[0]
+ dirname = vfile.split('/')[-2]
+
+ fulldir = os.path.join(self.config['preprocessed_root'], dirname, vidname)
+ os.makedirs(fulldir, exist_ok=True)
+
+ batches = [frames[i:i + self.config['preprocess_batch_size']] for i in range(0, len(frames), self.config['preprocess_batch_size'])]
+
+ i = -1
+ for fb in batches:
+ preds = self.fa[gpu_id].get_detections_for_batch(np.asarray(fb))
+
+ for j, f in enumerate(preds):
+ i += 1
+ if f is None:
+ continue
+
+ x1, y1, x2, y2 = f
+ cv2.imwrite(os.path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2])
+
+
+ def process_audio_file(self, vfile):
+ vidname = os.path.basename(vfile).split('.')[0]
+ dirname = vfile.split('/')[-2]
+
+ fulldir = os.path.join(self.config['preprocessed_root'], dirname, vidname)
+ os.makedirs(fulldir, exist_ok=True)
+
+ wavpath = os.path.join(fulldir, 'audio.wav')
+
+ command =self.template.format(vfile, wavpath)
+ subprocess.call(command, shell=True)
+
+ def mp_handler(self, job):
+ vfile, gpu_id = job
+ try:
+ self.process_video_file(vfile, gpu_id)
+ except KeyboardInterrupt:
+ exit(0)
+ except Exception:
+ traceback.print_exc()
+
+ def run(self):
+ print(f'Started processing for {self.config["data_root"]} with {self.config["ngpu"]} GPUs')
+
+ filelist = glob(os.path.join(self.config["data_root"], '*/*.mp4'))
+
+ # jobs = [(vfile, i % self.config["ngpu"]) for i, vfile in enumerate(filelist)]
+ # with ThreadPoolExecutor(self.config["ngpu"]) as p:
+ # futures = [p.submit(self.mp_handler, j) for j in jobs]
+ # _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))]
+
+ print('Dumping audios...')
+ for vfile in tqdm(filelist):
+ try:
+ self.process_audio_file(vfile)
+ except KeyboardInterrupt:
+ exit(0)
+ except Exception:
+ traceback.print_exc()
+ continue
+
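`process_video_file` splits the decoded frame list into fixed-size batches with a slice comprehension (`frames[i:i + batch_size] for i in range(0, len(frames), batch_size)`). The chunking idiom in isolation; the batch size 3 here is arbitrary:

```python
def chunk(items, batch_size):
    """Split a list into consecutive batches of at most batch_size items."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

frames = list(range(7))
print(chunk(frames, 3))  # [[0, 1, 2], [3, 4, 5], [6]]
```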
diff --git a/talkingface-toolkit-main/talkingface/utils/enum_type.py b/talkingface-toolkit-main/talkingface/utils/enum_type.py
new file mode 100644
index 00000000..b701b38b
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/enum_type.py
@@ -0,0 +1,13 @@
+from enum import Enum
+
+class EvaluatorType(Enum):
+ """Type for evaluation metrics.
+
+ - ``SYNC``: SYNC metrics like LSE-C, LSE-D, etc.
+ - ``VIDEOQ``: Video quality metrics like FID, etc.
+ - ``AUDIOQ``: Audio quality metrics like PESQ, etc.
+ """
+
+ SYNC = 1
+ VIDEOQ = 2
+ AUDIOQ = 3
\ No newline at end of file
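The `EvaluatorType` enum above lets evaluation code dispatch on metric family (sync, video quality, audio quality). A small usage sketch; the `metric_type` mapping and `needs_ground_truth_video` helper are hypothetical illustrations, not part of the toolkit:

```python
from enum import Enum

class EvaluatorType(Enum):
    SYNC = 1
    VIDEOQ = 2
    AUDIOQ = 3

# Hypothetical mapping from metric name to its family.
metric_type = {
    "LSE-C": EvaluatorType.SYNC,
    "FID": EvaluatorType.VIDEOQ,
    "PESQ": EvaluatorType.AUDIOQ,
}

def needs_ground_truth_video(metric):
    # Video-quality metrics compare generated vs. real videos; sync metrics
    # only need the generated video list.
    return metric_type[metric] is EvaluatorType.VIDEOQ

print(needs_ground_truth_video("FID"))   # True
```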
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/README.md b/talkingface-toolkit-main/talkingface/utils/face_detection/README.md
new file mode 100644
index 00000000..c073376e
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/README.md
@@ -0,0 +1 @@
+The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/__init__.py b/talkingface-toolkit-main/talkingface/utils/face_detection/__init__.py
new file mode 100644
index 00000000..4bae29fd
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/__init__.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+
+__author__ = """Adrian Bulat"""
+__email__ = 'adrian.bulat@nottingham.ac.uk'
+__version__ = '1.0.1'
+
+from .api import FaceAlignment, LandmarksType, NetworkSize
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/api.py b/talkingface-toolkit-main/talkingface/utils/face_detection/api.py
new file mode 100644
index 00000000..dc4dd4f2
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/api.py
@@ -0,0 +1,79 @@
+from __future__ import print_function
+import os
+import torch
+from torch.utils.model_zoo import load_url
+from enum import Enum
+import numpy as np
+import cv2
+try:
+ import urllib.request as request_file
+except BaseException:
+ import urllib as request_file
+
+# from .models import FAN, ResNetDepth
+# from .utils import *
+
+
+class LandmarksType(Enum):
+ """Enum class defining the type of landmarks to detect.
+
+ ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
+ ``_2halfD`` - these points are the 2D projection of the 3D points
+ ``_3D`` - detect the points ``(x,y,z)`` in a 3D space
+
+ """
+ _2D = 1
+ _2halfD = 2
+ _3D = 3
+
+
+class NetworkSize(Enum):
+ # TINY = 1
+ # SMALL = 2
+ # MEDIUM = 3
+ LARGE = 4
+
+ def __new__(cls, value):
+ member = object.__new__(cls)
+ member._value_ = value
+ return member
+
+ def __int__(self):
+ return self.value
+
+ROOT = os.path.dirname(os.path.abspath(__file__))
+
+class FaceAlignment:
+ def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
+ device='cuda', flip_input=False, face_detector='sfd', verbose=False):
+ self.device = device
+ self.flip_input = flip_input
+ self.landmarks_type = landmarks_type
+ self.verbose = verbose
+
+ network_size = int(network_size)
+
+ if 'cuda' in device:
+ torch.backends.cudnn.benchmark = True
+
+ # Get the face detector
+ face_detector_module = __import__('talkingface.utils.face_detection.detection.' + face_detector,
+ globals(), locals(), [face_detector], 0)
+ self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
+
+ def get_detections_for_batch(self, images):
+ images = images[..., ::-1]
+ detected_faces = self.face_detector.detect_from_batch(images.copy())
+ results = []
+
+ for i, d in enumerate(detected_faces):
+ if len(d) == 0:
+ results.append(None)
+ continue
+ d = d[0]
+ d = np.clip(d, 0, None)
+
+ x1, y1, x2, y2 = map(int, d[:-1])
+ results.append((x1, y1, x2, y2))
+
+ return results
\ No newline at end of file
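`get_detections_for_batch` above keeps only the first detection per image, clamps negative coordinates to zero with `np.clip`, and drops the trailing confidence score before casting to `int`. The same post-processing in isolation (the `postprocess_detection` name is mine, for illustration):

```python
import numpy as np

def postprocess_detection(det):
    """det is (x1, y1, x2, y2, score); return an int box clamped to >= 0, or None."""
    if det is None or len(det) == 0:
        return None
    d = np.clip(np.asarray(det, dtype=float), 0, None)
    x1, y1, x2, y2 = map(int, d[:-1])   # drop the trailing confidence score
    return (x1, y1, x2, y2)

print(postprocess_detection((-3.2, 4.9, 120.7, 180.1, 0.98)))  # (0, 4, 120, 180)
```

Clamping matters because S3FD can emit boxes that extend slightly past the image border, and negative indices would wrap around when slicing the frame.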
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/__init__.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/__init__.py
new file mode 100644
index 00000000..1a6b0402
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/__init__.py
@@ -0,0 +1 @@
+from .core import FaceDetector
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/core.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/core.py
new file mode 100644
index 00000000..0f8275e8
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/core.py
@@ -0,0 +1,130 @@
+import logging
+import glob
+from tqdm import tqdm
+import numpy as np
+import torch
+import cv2
+
+
+class FaceDetector(object):
+ """An abstract class representing a face detector.
+
+ Any other face detection implementation must subclass it. All subclasses
+ must implement ``detect_from_image``, which returns a list of detected
+ bounding boxes. Optionally, for speed, detecting directly from a file
+ path is recommended.
+ """
+
+ def __init__(self, device, verbose):
+ self.device = device
+ self.verbose = verbose
+ logger = logging.getLogger(__name__)
+
+ if verbose and 'cpu' in device:
+ logger.warning("Detection running on CPU, this may be potentially slow.")
+
+ if 'cpu' not in device and 'cuda' not in device:
+ if verbose:
+ logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
+ raise ValueError("Expected device to be 'cpu' or 'cuda', but got: {}".format(device))
+
+ def detect_from_image(self, tensor_or_path):
+ """Detects faces in a given image.
+
+ This function detects the faces present in a provided BGR (usually)
+ image. The input can be either the image itself or the path to it.
+
+ Arguments:
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
+ to an image or the image itself.
+
+ Example::
+
+ >>> path_to_image = 'data/image_01.jpg'
+ ... detected_faces = detect_from_image(path_to_image)
+ [A list of bounding boxes (x1, y1, x2, y2)]
+ >>> image = cv2.imread(path_to_image)
+ ... detected_faces = detect_from_image(image)
+ [A list of bounding boxes (x1, y1, x2, y2)]
+
+ """
+ raise NotImplementedError
+
+ def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
+ """Detects faces from all the images present in a given directory.
+
+ Arguments:
+ path {string} -- a string containing a path that points to the folder containing the images
+
+ Keyword Arguments:
+ extensions {list} -- list of strings containing the extensions to be
+ considered, in the format ``.extension_name`` (default: {['.jpg', '.png']})
+ recursive {bool} -- whether to scan the folder recursively (default: {False})
+ show_progress_bar {bool} -- display a progress bar (default: {True})
+
+ Example:
+ >>> directory = 'data'
+ ... detected_faces = detect_from_directory(directory)
+ {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
+
+ """
+ if self.verbose:
+ logger = logging.getLogger(__name__)
+
+ if len(extensions) == 0:
+ if self.verbose:
+ logger.error("Expected at least one extension, but none was received.")
+ raise ValueError
+
+ if self.verbose:
+ logger.info("Constructing the list of images.")
+ additional_pattern = '/**/*' if recursive else '/*'
+ files = []
+ for extension in extensions:
+ files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
+
+ if self.verbose:
+ logger.info("Finished searching for images. %s images found", len(files))
+ logger.info("Preparing to run the detection.")
+
+ predictions = {}
+ for image_path in tqdm(files, disable=not show_progress_bar):
+ if self.verbose:
+ logger.info("Running the face detector on image: %s", image_path)
+ predictions[image_path] = self.detect_from_image(image_path)
+
+ if self.verbose:
+ logger.info("The detector was successfully run on all %s images", len(files))
+
+ return predictions
+
+ @property
+ def reference_scale(self):
+ raise NotImplementedError
+
+ @property
+ def reference_x_shift(self):
+ raise NotImplementedError
+
+ @property
+ def reference_y_shift(self):
+ raise NotImplementedError
+
+ @staticmethod
+ def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
+ """Convert path (represented as a string) or torch.tensor to a numpy.ndarray
+
+ Arguments:
+ tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
+ """
+ if isinstance(tensor_or_path, str):
+ return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
+ elif torch.is_tensor(tensor_or_path):
+ # Call cpu in case its coming from cuda
+ return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
+ elif isinstance(tensor_or_path, np.ndarray):
+ return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
+ else:
+ raise TypeError
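`tensor_or_path_to_ndarray` normalises three input kinds (path, tensor, array) to a single `numpy.ndarray`, reversing the last axis to convert between BGR and RGB channel order. The channel flip on its own, with NumPy only:

```python
import numpy as np

# One pure-blue pixel in OpenCV's BGR channel order.
bgr = np.array([[[255, 0, 0]]], dtype=np.uint8)

# Reversing the last (channel) axis converts BGR <-> RGB.
rgb = bgr[..., ::-1]
print(rgb.tolist())  # [[[0, 0, 255]]]
```

The flip is its own inverse, which is why the same `[..., ::-1]` slice appears both here and in `FaceAlignment.get_detections_for_batch` above.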
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/__init__.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/__init__.py
new file mode 100644
index 00000000..5a63ecd4
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/__init__.py
@@ -0,0 +1 @@
+from .sfd_detector import SFDDetector as FaceDetector
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/bbox.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/bbox.py
new file mode 100644
index 00000000..4bd7222e
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/bbox.py
@@ -0,0 +1,129 @@
+from __future__ import print_function
+import math
+
+import numpy as np
+import torch
+
+try:
+ from iou import IOU
+except BaseException:
+ # IOU cython speedup 10x
+ def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
+ sa = abs((ax2 - ax1) * (ay2 - ay1))
+ sb = abs((bx2 - bx1) * (by2 - by1))
+ x1, y1 = max(ax1, bx1), max(ay1, by1)
+ x2, y2 = min(ax2, bx2), min(ay2, by2)
+ w = x2 - x1
+ h = y2 - y1
+ if w < 0 or h < 0:
+ return 0.0
+ else:
+ return 1.0 * w * h / (sa + sb - w * h)
+
+
+def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
+ xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
+ dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
+ dw, dh = math.log(ww / aww), math.log(hh / ahh)
+ return dx, dy, dw, dh
+
+
+def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
+ xc, yc = dx * aww + axc, dy * ahh + ayc
+ ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
+ x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
+ return x1, y1, x2, y2
+
+
+def nms(dets, thresh):
+ if len(dets) == 0:
+ return []
+ x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ keep = []
+ while order.size > 0:
+ i = order[0]
+ keep.append(i)
+ xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
+ xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
+
+ w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
+ ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
+
+ inds = np.where(ovr <= thresh)[0]
+ order = order[inds + 1]
+
+ return keep
+
+
+def encode(matched, priors, variances):
+ """Encode the variances from the priorbox layers into the ground truth boxes
+ we have matched (based on jaccard overlap) with the prior boxes.
+ Args:
+ matched: (tensor) Coords of ground truth for each prior in point-form
+ Shape: [num_priors, 4].
+ priors: (tensor) Prior boxes in center-offset form
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ encoded boxes (tensor), Shape: [num_priors, 4]
+ """
+
+ # dist b/t match center and prior's center
+ g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
+ # encode variance
+ g_cxcy /= (variances[0] * priors[:, 2:])
+ # match wh / prior wh
+ g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
+ g_wh = torch.log(g_wh) / variances[1]
+ # return target for smooth_l1_loss
+ return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
+
+
+def decode(loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ loc (tensor): location predictions for loc layers,
+ Shape: [num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+
+ boxes = torch.cat((
+ priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
+ priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
+ boxes[:, :2] -= boxes[:, 2:] / 2
+ boxes[:, 2:] += boxes[:, :2]
+ return boxes
+
+def batch_decode(loc, priors, variances):
+ """Decode locations from predictions using priors to undo
+ the encoding we did for offset regression at train time.
+ Args:
+ loc (tensor): location predictions for loc layers,
+ Shape: [num_priors,4]
+ priors (tensor): Prior boxes in center-offset form.
+ Shape: [num_priors,4].
+ variances: (list[float]) Variances of priorboxes
+ Return:
+ decoded bounding box predictions
+ """
+
+ boxes = torch.cat((
+ priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
+ priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
+ boxes[:, :, :2] -= boxes[:, :, 2:] / 2
+ boxes[:, :, 2:] += boxes[:, :, :2]
+ return boxes
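The `nms` function above greedily keeps the highest-scoring box and suppresses any remaining box whose IoU with it exceeds `thresh`. A quick check with two heavily overlapping boxes and one distant box, reproducing the same math (the only change from the module's version is casting kept indices to plain `int` for readability):

```python
import numpy as np

def nms(dets, thresh):
    """dets rows are [x1, y1, x2, y2, score]; returns indices of kept boxes."""
    if len(dets) == 0:
        return []
    x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]       # indices by descending score
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        # Intersection of box i with every remaining box.
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
        order = order[np.where(ovr <= thresh)[0] + 1]
    return keep

dets = np.array([
    [0, 0, 10, 10, 0.9],    # kept: highest score
    [1, 1, 11, 11, 0.8],    # suppressed: large overlap with box 0
    [50, 50, 60, 60, 0.7],  # kept: no overlap
])
print(nms(dets, 0.3))       # [0, 2]
```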
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/detect.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/detect.py
new file mode 100644
index 00000000..efef6273
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/detect.py
@@ -0,0 +1,112 @@
+import torch
+import torch.nn.functional as F
+
+import cv2
+import numpy as np
+
+from .net_s3fd import s3fd
+from .bbox import *
+
+
+def detect(net, img, device):
+ img = img - np.array([104, 117, 123])
+ img = img.transpose(2, 0, 1)
+ img = img.reshape((1,) + img.shape)
+
+ if 'cuda' in device:
+ torch.backends.cudnn.benchmark = True
+
+ img = torch.from_numpy(img).float().to(device)
+ BB, CC, HH, WW = img.size()
+ with torch.no_grad():
+ olist = net(img)
+
+ bboxlist = []
+ for i in range(len(olist) // 2):
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
+ olist = [oelem.data.cpu() for oelem in olist]
+ for i in range(len(olist) // 2):
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
+ FB, FC, FH, FW = ocls.size() # feature map size
+ stride = 2**(i + 2) # 4,8,16,32,64,128
+ anchor = stride * 4
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
+ for Iindex, hindex, windex in poss:
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
+ score = ocls[0, 1, hindex, windex]
+ loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
+ variances = [0.1, 0.2]
+ box = decode(loc, priors, variances)
+ x1, y1, x2, y2 = box[0] * 1.0
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
+ bboxlist.append([x1, y1, x2, y2, score])
+ bboxlist = np.array(bboxlist)
+ if len(bboxlist) == 0:
+ bboxlist = np.zeros((1, 5))
+
+ return bboxlist
+
+def batch_detect(net, imgs, device):
+ imgs = imgs - np.array([104, 117, 123])
+ imgs = imgs.transpose(0, 3, 1, 2)
+
+ if 'cuda' in device:
+ torch.backends.cudnn.benchmark = True
+
+ imgs = torch.from_numpy(imgs).float().to(device)
+ BB, CC, HH, WW = imgs.size()
+ with torch.no_grad():
+ olist = net(imgs)
+
+ bboxlist = []
+ for i in range(len(olist) // 2):
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1)
+ olist = [oelem.data.cpu() for oelem in olist]
+ for i in range(len(olist) // 2):
+ ocls, oreg = olist[i * 2], olist[i * 2 + 1]
+ FB, FC, FH, FW = ocls.size() # feature map size
+ stride = 2**(i + 2) # 4,8,16,32,64,128
+ anchor = stride * 4
+ poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
+ for Iindex, hindex, windex in poss:
+ axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
+ score = ocls[:, 1, hindex, windex]
+ loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
+ priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
+ variances = [0.1, 0.2]
+ box = batch_decode(loc, priors, variances)
+ box = box[:, 0] * 1.0
+ # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
+ bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
+ bboxlist = np.array(bboxlist)
+ if len(bboxlist) == 0:
+ bboxlist = np.zeros((1, BB, 5))
+
+ return bboxlist
+
+def flip_detect(net, img, device):
+ img = cv2.flip(img, 1)
+ b = detect(net, img, device)
+
+ bboxlist = np.zeros(b.shape)
+ bboxlist[:, 0] = img.shape[1] - b[:, 2]
+ bboxlist[:, 1] = b[:, 1]
+ bboxlist[:, 2] = img.shape[1] - b[:, 0]
+ bboxlist[:, 3] = b[:, 3]
+ bboxlist[:, 4] = b[:, 4]
+ return bboxlist
+
+
+def pts_to_bb(pts):
+ min_x, min_y = np.min(pts, axis=0)
+ max_x, max_y = np.max(pts, axis=0)
+ return np.array([min_x, min_y, max_x, max_y])
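`flip_detect` above runs the detector on a horizontally mirrored image, then maps each box back by reflecting its x coordinates about the image width: the new `x1` is the mirrored old `x2` and vice versa. That remapping in isolation (the `mirror_boxes` name is mine, for illustration):

```python
import numpy as np

def mirror_boxes(boxes, width):
    """Map [x1, y1, x2, y2, score] boxes from a horizontally flipped image
    back to the original image's coordinates."""
    out = boxes.copy()
    out[:, 0] = width - boxes[:, 2]   # new x1 is the mirrored old x2
    out[:, 2] = width - boxes[:, 0]   # new x2 is the mirrored old x1
    return out

b = np.array([[10.0, 5.0, 30.0, 25.0, 0.9]])
print(mirror_boxes(b, 100).tolist())  # [[70.0, 5.0, 90.0, 25.0, 0.9]]
```

Swapping `x1` and `x2` during the reflection keeps `x1 <= x2`, so downstream slicing and IoU code still work; y coordinates and scores pass through unchanged.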
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/net_s3fd.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/net_s3fd.py
new file mode 100644
index 00000000..fc64313c
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/net_s3fd.py
@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class L2Norm(nn.Module):
+ def __init__(self, n_channels, scale=1.0):
+ super(L2Norm, self).__init__()
+ self.n_channels = n_channels
+ self.scale = scale
+ self.eps = 1e-10
+ self.weight = nn.Parameter(torch.Tensor(self.n_channels))
+ self.weight.data *= 0.0
+ self.weight.data += self.scale
+
+ def forward(self, x):
+ norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
+ x = x / norm * self.weight.view(1, -1, 1, 1)
+ return x
+
+
+class s3fd(nn.Module):
+ def __init__(self):
+ super(s3fd, self).__init__()
+ self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
+ self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+
+ self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
+ self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
+
+ self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
+ self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+ self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+
+ self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
+ self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+ self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+
+ self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+ self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+ self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+
+ self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
+ self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
+
+ self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
+ self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
+
+ self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
+ self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
+
+ self.conv3_3_norm = L2Norm(256, scale=10)
+ self.conv4_3_norm = L2Norm(512, scale=8)
+ self.conv5_3_norm = L2Norm(512, scale=5)
+
+ self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
+ self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
+ self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
+ self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+ self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
+ self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+
+ self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
+ self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
+ self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
+ self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
+ self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
+
+ def forward(self, x):
+ h = F.relu(self.conv1_1(x))
+ h = F.relu(self.conv1_2(h))
+ h = F.max_pool2d(h, 2, 2)
+
+ h = F.relu(self.conv2_1(h))
+ h = F.relu(self.conv2_2(h))
+ h = F.max_pool2d(h, 2, 2)
+
+ h = F.relu(self.conv3_1(h))
+ h = F.relu(self.conv3_2(h))
+ h = F.relu(self.conv3_3(h))
+ f3_3 = h
+ h = F.max_pool2d(h, 2, 2)
+
+ h = F.relu(self.conv4_1(h))
+ h = F.relu(self.conv4_2(h))
+ h = F.relu(self.conv4_3(h))
+ f4_3 = h
+ h = F.max_pool2d(h, 2, 2)
+
+ h = F.relu(self.conv5_1(h))
+ h = F.relu(self.conv5_2(h))
+ h = F.relu(self.conv5_3(h))
+ f5_3 = h
+ h = F.max_pool2d(h, 2, 2)
+
+ h = F.relu(self.fc6(h))
+ h = F.relu(self.fc7(h))
+ ffc7 = h
+ h = F.relu(self.conv6_1(h))
+ h = F.relu(self.conv6_2(h))
+ f6_2 = h
+ h = F.relu(self.conv7_1(h))
+ h = F.relu(self.conv7_2(h))
+ f7_2 = h
+
+ f3_3 = self.conv3_3_norm(f3_3)
+ f4_3 = self.conv4_3_norm(f4_3)
+ f5_3 = self.conv5_3_norm(f5_3)
+
+ cls1 = self.conv3_3_norm_mbox_conf(f3_3)
+ reg1 = self.conv3_3_norm_mbox_loc(f3_3)
+ cls2 = self.conv4_3_norm_mbox_conf(f4_3)
+ reg2 = self.conv4_3_norm_mbox_loc(f4_3)
+ cls3 = self.conv5_3_norm_mbox_conf(f5_3)
+ reg3 = self.conv5_3_norm_mbox_loc(f5_3)
+ cls4 = self.fc7_mbox_conf(ffc7)
+ reg4 = self.fc7_mbox_loc(ffc7)
+ cls5 = self.conv6_2_mbox_conf(f6_2)
+ reg5 = self.conv6_2_mbox_loc(f6_2)
+ cls6 = self.conv7_2_mbox_conf(f7_2)
+ reg6 = self.conv7_2_mbox_loc(f7_2)
+
+ # max-out background label
+ chunk = torch.chunk(cls1, 4, 1)
+ bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
+ cls1 = torch.cat([bmax, chunk[3]], dim=1)
+
+ return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
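The final step of `forward` above is S3FD's max-out background label: the first (highest-resolution) confidence layer predicts three background scores plus one face score per location, and only the maximum background score is kept, reducing four channels to two. The same reduction sketched in NumPy (the torch version uses `torch.chunk` and `torch.max`; `maxout_background` is my name for it):

```python
import numpy as np

def maxout_background(cls1):
    """cls1: (N, 4, H, W) -> (N, 2, H, W). Channels 0-2 are background
    scores, channel 3 is the face score; keep only the best background."""
    bmax = cls1[:, :3].max(axis=1, keepdims=True)   # best of 3 background scores
    return np.concatenate([bmax, cls1[:, 3:4]], axis=1)

x = np.arange(16, dtype=float).reshape(1, 4, 2, 2)
out = maxout_background(x)
print(out.shape)  # (1, 2, 2, 2)
```

The trick reduces false positives from the dense small-anchor layer without changing the two-channel (background, face) interface the decoding code expects.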
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/sfd_detector.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/sfd_detector.py
new file mode 100644
index 00000000..8fbce152
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/sfd_detector.py
@@ -0,0 +1,59 @@
+import os
+import cv2
+from torch.utils.model_zoo import load_url
+
+from ..core import FaceDetector
+
+from .net_s3fd import s3fd
+from .bbox import *
+from .detect import *
+
+models_urls = {
+ 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
+}
+
+
+class SFDDetector(FaceDetector):
+ def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
+ super(SFDDetector, self).__init__(device, verbose)
+
+ # Initialise the face detector
+ if not os.path.isfile(path_to_detector):
+ model_weights = load_url(models_urls['s3fd'])
+ else:
+ model_weights = torch.load(path_to_detector)
+
+ self.face_detector = s3fd()
+ self.face_detector.load_state_dict(model_weights)
+ self.face_detector.to(device)
+ self.face_detector.eval()
+
+ def detect_from_image(self, tensor_or_path):
+ image = self.tensor_or_path_to_ndarray(tensor_or_path)
+
+ bboxlist = detect(self.face_detector, image, device=self.device)
+ keep = nms(bboxlist, 0.3)
+ bboxlist = bboxlist[keep, :]
+ bboxlist = [x for x in bboxlist if x[-1] > 0.5]
+
+ return bboxlist
+
+ def detect_from_batch(self, images):
+ bboxlists = batch_detect(self.face_detector, images, device=self.device)
+ keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
+ bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
+ bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
+
+ return bboxlists
+
+ @property
+ def reference_scale(self):
+ return 195
+
+ @property
+ def reference_x_shift(self):
+ return 0
+
+ @property
+ def reference_y_shift(self):
+ return 0
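`detect_from_batch` above finishes by dropping low-confidence boxes per image with a list comprehension over the last column of each box. That filtering step in isolation (`filter_by_score` is an illustrative name; 0.5 matches the threshold hard-coded above):

```python
def filter_by_score(bboxlists, thresh=0.5):
    """Keep only boxes whose trailing confidence exceeds thresh, per image."""
    return [[box for box in bboxlist if box[-1] > thresh] for bboxlist in bboxlists]

batch = [
    [(0, 0, 10, 10, 0.9), (5, 5, 15, 15, 0.2)],  # image 0: one box survives
    [(1, 1, 4, 4, 0.4)],                          # image 1: nothing survives
]
print(filter_by_score(batch))  # [[(0, 0, 10, 10, 0.9)], []]
```

An image with no surviving boxes yields an empty list rather than being dropped, which is what lets `get_detections_for_batch` in api.py emit `None` for that frame while keeping indices aligned.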
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/models.py b/talkingface-toolkit-main/talkingface/utils/face_detection/models.py
new file mode 100644
index 00000000..ee2dde32
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/models.py
@@ -0,0 +1,261 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+
+def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
+ "3x3 convolution with padding"
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3,
+ stride=strd, padding=padding, bias=bias)
+
+
+class ConvBlock(nn.Module):
+ def __init__(self, in_planes, out_planes):
+ super(ConvBlock, self).__init__()
+ self.bn1 = nn.BatchNorm2d(in_planes)
+ self.conv1 = conv3x3(in_planes, int(out_planes / 2))
+ self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
+ self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
+ self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
+ self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
+
+ if in_planes != out_planes:
+ self.downsample = nn.Sequential(
+ nn.BatchNorm2d(in_planes),
+ nn.ReLU(True),
+ nn.Conv2d(in_planes, out_planes,
+ kernel_size=1, stride=1, bias=False),
+ )
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ residual = x
+
+ out1 = self.bn1(x)
+ out1 = F.relu(out1, True)
+ out1 = self.conv1(out1)
+
+ out2 = self.bn2(out1)
+ out2 = F.relu(out2, True)
+ out2 = self.conv2(out2)
+
+ out3 = self.bn3(out2)
+ out3 = F.relu(out3, True)
+ out3 = self.conv3(out3)
+
+ out3 = torch.cat((out1, out2, out3), 1)
+
+ if self.downsample is not None:
+ residual = self.downsample(residual)
+
+ out3 += residual
+
+ return out3
+
+
+class Bottleneck(nn.Module):
+
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
+ padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class HourGlass(nn.Module):
+ def __init__(self, num_modules, depth, num_features):
+ super(HourGlass, self).__init__()
+ self.num_modules = num_modules
+ self.depth = depth
+ self.features = num_features
+
+ self._generate_network(self.depth)
+
+ def _generate_network(self, level):
+ self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
+
+ self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
+
+ if level > 1:
+ self._generate_network(level - 1)
+ else:
+ self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
+
+ self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
+
+ def _forward(self, level, inp):
+ # Upper branch
+ up1 = inp
+ up1 = self._modules['b1_' + str(level)](up1)
+
+ # Lower branch
+ low1 = F.avg_pool2d(inp, 2, stride=2)
+ low1 = self._modules['b2_' + str(level)](low1)
+
+ if level > 1:
+ low2 = self._forward(level - 1, low1)
+ else:
+ low2 = low1
+ low2 = self._modules['b2_plus_' + str(level)](low2)
+
+ low3 = low2
+ low3 = self._modules['b3_' + str(level)](low3)
+
+ up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
+
+ return up1 + up2
+
+ def forward(self, x):
+ return self._forward(self.depth, x)
+
+
+class FAN(nn.Module):
+
+ def __init__(self, num_modules=1):
+ super(FAN, self).__init__()
+ self.num_modules = num_modules
+
+ # Base part
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.conv2 = ConvBlock(64, 128)
+ self.conv3 = ConvBlock(128, 128)
+ self.conv4 = ConvBlock(128, 256)
+
+ # Stacking part
+ for hg_module in range(self.num_modules):
+ self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
+ self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
+ self.add_module('conv_last' + str(hg_module),
+ nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+ self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
+ self.add_module('l' + str(hg_module), nn.Conv2d(256,
+ 68, kernel_size=1, stride=1, padding=0))
+
+ if hg_module < self.num_modules - 1:
+ self.add_module(
+ 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
+ self.add_module('al' + str(hg_module), nn.Conv2d(68,
+ 256, kernel_size=1, stride=1, padding=0))
+
+ def forward(self, x):
+ x = F.relu(self.bn1(self.conv1(x)), True)
+ x = F.avg_pool2d(self.conv2(x), 2, stride=2)
+ x = self.conv3(x)
+ x = self.conv4(x)
+
+ previous = x
+
+ outputs = []
+ for i in range(self.num_modules):
+ hg = self._modules['m' + str(i)](previous)
+
+ ll = hg
+ ll = self._modules['top_m_' + str(i)](ll)
+
+ ll = F.relu(self._modules['bn_end' + str(i)]
+ (self._modules['conv_last' + str(i)](ll)), True)
+
+ # Predict heatmaps
+ tmp_out = self._modules['l' + str(i)](ll)
+ outputs.append(tmp_out)
+
+ if i < self.num_modules - 1:
+ ll = self._modules['bl' + str(i)](ll)
+ tmp_out_ = self._modules['al' + str(i)](tmp_out)
+ previous = previous + ll + tmp_out_
+
+ return outputs
+
+
+class ResNetDepth(nn.Module):
+
+ def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
+ self.inplanes = 64
+ super(ResNetDepth, self).__init__()
+ self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
+ bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+ self.avgpool = nn.AvgPool2d(7)
+ self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ m.weight.data.normal_(0, math.sqrt(2. / n))
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+ def _make_layer(self, block, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(self.inplanes, planes * block.expansion,
+ kernel_size=1, stride=stride, bias=False),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(block(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ x = self.avgpool(x)
+ x = x.view(x.size(0), -1)
+ x = self.fc(x)
+
+ return x
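`ConvBlock` above concatenates three branches of `out_planes/2`, `out_planes/4`, and `out_planes/4` channels, so the concatenation recovers exactly `out_planes` channels whenever `out_planes` is divisible by 4 (as it is for every width used in `FAN`). A quick arithmetic check of that invariant:

```python
def convblock_branch_channels(out_planes):
    """Channel widths of ConvBlock's three branches; they sum to out_planes."""
    return out_planes // 2, out_planes // 4, out_planes // 4

for planes in (128, 256):
    branches = convblock_branch_channels(planes)
    print(planes, branches, sum(branches) == planes)
```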
diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/utils.py b/talkingface-toolkit-main/talkingface/utils/face_detection/utils.py
new file mode 100644
index 00000000..3dc4cf3e
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/face_detection/utils.py
@@ -0,0 +1,313 @@
+from __future__ import print_function
+import os
+import sys
+import time
+import torch
+import math
+import numpy as np
+import cv2
+
+
+def _gaussian(
+ size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
+ height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
+ mean_vert=0.5):
+ # handle some defaults
+ if width is None:
+ width = size
+ if height is None:
+ height = size
+ if sigma_horz is None:
+ sigma_horz = sigma
+ if sigma_vert is None:
+ sigma_vert = sigma
+ center_x = mean_horz * width + 0.5
+ center_y = mean_vert * height + 0.5
+ gauss = np.empty((height, width), dtype=np.float32)
+ # generate kernel
+ for i in range(height):
+ for j in range(width):
+ gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
+ sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
+ if normalize:
+ gauss = gauss / np.sum(gauss)
+ return gauss
+
+
+def draw_gaussian(image, point, sigma):
+ # Check if the gaussian is inside
+ ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
+ br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
+ if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
+ return image
+ size = 6 * sigma + 1
+ g = _gaussian(size)
+ g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
+ g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
+ img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
+ img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
+ assert (g_x[0] > 0 and g_y[1] > 0)
+ image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
+ ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
+ image[image > 1] = 1
+ return image
+
+
+def transform(point, center, scale, resolution, invert=False):
+    """Generate an affine transformation matrix.
+
+    Given a point, a center, a scale and a target resolution, the
+    function generates an affine transformation matrix. If invert is ``True``
+    it will produce the inverse transformation.
+
+ Arguments:
+ point {torch.tensor} -- the input 2D point
+ center {torch.tensor or numpy.array} -- the center around which to perform the transformations
+ scale {float} -- the scale of the face/object
+ resolution {float} -- the output resolution
+
+ Keyword Arguments:
+        invert {bool} -- whether the function should produce the direct or the
+        inverse transformation matrix (default: {False})
+ """
+ _pt = torch.ones(3)
+ _pt[0] = point[0]
+ _pt[1] = point[1]
+
+ h = 200.0 * scale
+ t = torch.eye(3)
+ t[0, 0] = resolution / h
+ t[1, 1] = resolution / h
+ t[0, 2] = resolution * (-center[0] / h + 0.5)
+ t[1, 2] = resolution * (-center[1] / h + 0.5)
+
+ if invert:
+ t = torch.inverse(t)
+
+ new_point = (torch.matmul(t, _pt))[0:2]
+
+ return new_point.int()
+
+
+def crop(image, center, scale, resolution=256.0):
+ """Center crops an image or set of heatmaps
+
+ Arguments:
+ image {numpy.array} -- an rgb image
+ center {numpy.array} -- the center of the object, usually the same as of the bounding box
+ scale {float} -- scale of the face
+
+ Keyword Arguments:
+ resolution {float} -- the size of the output cropped image (default: {256.0})
+
+    Returns:
+        numpy.array -- the cropped image, resized to (resolution, resolution)
+    """
+ ul = transform([1, 1], center, scale, resolution, True)
+ br = transform([resolution, resolution], center, scale, resolution, True)
+ # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
+ if image.ndim > 2:
+ newDim = np.array([br[1] - ul[1], br[0] - ul[0],
+ image.shape[2]], dtype=np.int32)
+ newImg = np.zeros(newDim, dtype=np.uint8)
+ else:
+        newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
+ newImg = np.zeros(newDim, dtype=np.uint8)
+ ht = image.shape[0]
+ wd = image.shape[1]
+ newX = np.array(
+ [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
+ newY = np.array(
+ [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
+ oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
+ oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
+ newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
+ ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
+ newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
+ interpolation=cv2.INTER_LINEAR)
+ return newImg
+
+
+def get_preds_fromhm(hm, center=None, scale=None):
+    """Obtain (x,y) coordinates given a set of N heatmaps. If the center
+    and the scale are provided, the function will also return the points
+    in the original coordinate frame.
+
+ Arguments:
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
+
+ Keyword Arguments:
+ center {torch.tensor} -- the center of the bounding box (default: {None})
+ scale {float} -- face scale (default: {None})
+ """
+    _, idx = torch.max(
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
+ idx += 1
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
+
+ for i in range(preds.size(0)):
+ for j in range(preds.size(1)):
+ hm_ = hm[i, j, :]
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
+ diff = torch.FloatTensor(
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
+ preds[i, j].add_(diff.sign_().mul_(.25))
+
+ preds.add_(-.5)
+
+ preds_orig = torch.zeros(preds.size())
+ if center is not None and scale is not None:
+ for i in range(hm.size(0)):
+ for j in range(hm.size(1)):
+ preds_orig[i, j] = transform(
+ preds[i, j], center, scale, hm.size(2), True)
+
+ return preds, preds_orig
+
+def get_preds_fromhm_batch(hm, centers=None, scales=None):
+    """Obtain (x,y) coordinates given a set of N heatmaps. If the centers
+    and the scales are provided, the function will also return the points
+    in the original coordinate frame.
+
+ Arguments:
+ hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
+
+ Keyword Arguments:
+ centers {torch.tensor} -- the centers of the bounding box (default: {None})
+ scales {float} -- face scales (default: {None})
+ """
+    _, idx = torch.max(
+ hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
+ idx += 1
+ preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
+ preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
+ preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
+
+ for i in range(preds.size(0)):
+ for j in range(preds.size(1)):
+ hm_ = hm[i, j, :]
+ pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
+ if pX > 0 and pX < 63 and pY > 0 and pY < 63:
+ diff = torch.FloatTensor(
+ [hm_[pY, pX + 1] - hm_[pY, pX - 1],
+ hm_[pY + 1, pX] - hm_[pY - 1, pX]])
+ preds[i, j].add_(diff.sign_().mul_(.25))
+
+ preds.add_(-.5)
+
+ preds_orig = torch.zeros(preds.size())
+ if centers is not None and scales is not None:
+ for i in range(hm.size(0)):
+ for j in range(hm.size(1)):
+ preds_orig[i, j] = transform(
+ preds[i, j], centers[i], scales[i], hm.size(2), True)
+
+ return preds, preds_orig
+
+def shuffle_lr(parts, pairs=None):
+ """Shuffle the points left-right according to the axis of symmetry
+ of the object.
+
+ Arguments:
+ parts {torch.tensor} -- a 3D or 4D object containing the
+ heatmaps.
+
+ Keyword Arguments:
+ pairs {list of integers} -- [order of the flipped points] (default: {None})
+ """
+ if pairs is None:
+ pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
+ 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
+ 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
+ 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
+ 62, 61, 60, 67, 66, 65]
+ if parts.ndimension() == 3:
+ parts = parts[pairs, ...]
+ else:
+ parts = parts[:, pairs, ...]
+
+ return parts
+
+
+def flip(tensor, is_label=False):
+ """Flip an image or a set of heatmaps left-right
+
+ Arguments:
+ tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
+
+ Keyword Arguments:
+        is_label {bool} -- [denote whether the input is an image or a set of heatmaps] (default: {False})
+ """
+ if not torch.is_tensor(tensor):
+ tensor = torch.from_numpy(tensor)
+
+ if is_label:
+ tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
+ else:
+ tensor = tensor.flip(tensor.ndimension() - 1)
+
+ return tensor
+
+# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
+
+
+def appdata_dir(appname=None, roaming=False):
+ """ appdata_dir(appname=None, roaming=False)
+
+ Get the path to the application directory, where applications are allowed
+ to write user specific files (e.g. configurations). For non-user specific
+ data, consider using common_appdata_dir().
+ If appname is given, a subdir is appended (and created if necessary).
+ If roaming is True, will prefer a roaming directory (Windows Vista/7).
+ """
+
+ # Define default user directory
+ userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
+ if userDir is None:
+ userDir = os.path.expanduser('~')
+ if not os.path.isdir(userDir): # pragma: no cover
+ userDir = '/var/tmp' # issue #54
+
+ # Get system app data dir
+ path = None
+ if sys.platform.startswith('win'):
+ path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
+ path = (path2 or path1) if roaming else (path1 or path2)
+ elif sys.platform.startswith('darwin'):
+ path = os.path.join(userDir, 'Library', 'Application Support')
+ # On Linux and as fallback
+ if not (path and os.path.isdir(path)):
+ path = userDir
+
+ # Maybe we should store things local to the executable (in case of a
+ # portable distro or a frozen application that wants to be portable)
+ prefix = sys.prefix
+ if getattr(sys, 'frozen', None):
+ prefix = os.path.abspath(os.path.dirname(sys.executable))
+ for reldir in ('settings', '../settings'):
+ localpath = os.path.abspath(os.path.join(prefix, reldir))
+ if os.path.isdir(localpath): # pragma: no cover
+ try:
+ open(os.path.join(localpath, 'test.write'), 'wb').close()
+ os.remove(os.path.join(localpath, 'test.write'))
+ except IOError:
+ pass # We cannot write in this directory
+ else:
+ path = localpath
+ break
+
+ # Get path specific for this app
+ if appname:
+ if path == userDir:
+ appname = '.' + appname.lstrip('.') # Make it a hidden directory
+ path = os.path.join(path, appname)
+ if not os.path.isdir(path): # pragma: no cover
+ os.mkdir(path)
+
+ # Done
+ return path
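
The affine matrix built inside `transform` above can be checked without torch; the sketch below mirrors the function body in numpy (the 200.0 reference height is taken verbatim from the code above; the sample center/scale values are arbitrary):

```python
import numpy as np

# Mirror of the 3x3 affine matrix built inside transform() above.
def make_transform(center, scale, resolution):
    h = 200.0 * scale  # reference height, as in the torch version
    t = np.eye(3)
    t[0, 0] = resolution / h
    t[1, 1] = resolution / h
    t[0, 2] = resolution * (-center[0] / h + 0.5)
    t[1, 2] = resolution * (-center[1] / h + 0.5)
    return t

def apply_transform(t, point):
    x, y, _ = t @ np.array([point[0], point[1], 1.0])
    return x, y

t = make_transform(center=(128.0, 128.0), scale=1.28, resolution=256.0)
fwd = apply_transform(t, (128.0, 128.0))       # the crop center lands at resolution/2
back = apply_transform(np.linalg.inv(t), fwd)  # invert=True undoes the mapping
```

Note that the torch version truncates the result with `.int()`, which this float sketch omits.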
diff --git a/talkingface-toolkit-main/talkingface/utils/logger.py b/talkingface-toolkit-main/talkingface/utils/logger.py
new file mode 100644
index 00000000..855dbb94
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/logger.py
@@ -0,0 +1,95 @@
+import logging
+import os
+import colorlog
+import re
+import hashlib
+from talkingface.utils.utils import get_local_time, ensure_dir
+from colorama import init
+
+log_colors_config = {
+ "DEBUG": "cyan",
+ "WARNING": "yellow",
+ "ERROR": "red",
+ "CRITICAL": "red",
+}
+
+class RemoveColorFilter(logging.Filter):
+ def filter(self, record):
+ if record:
+ ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
+ record.msg = ansi_escape.sub("", str(record.msg))
+ return True
+
+def set_color(log, color, highlight=True):
+ color_set = ["black", "red", "green", "yellow", "blue", "pink", "cyan", "white"]
+ try:
+ index = color_set.index(color)
+    except ValueError:
+ index = len(color_set) - 1
+ prev_log = "\033["
+ if highlight:
+ prev_log += "1;3"
+ else:
+ prev_log += "0;3"
+ prev_log += str(index) + "m"
+ return prev_log + log + "\033[0m"
+
+def init_logger(config):
+ """
+ A logger that can show a message on standard output and write it into the
+ file named `filename` simultaneously.
+ All the message that you want to log MUST be str.
+
+ Args:
+ config (Config): An instance object of Config, used to record parameter information.
+
+ Example:
+        >>> init_logger(config)
+        >>> logger = logging.getLogger()
+ >>> logger.debug(train_state)
+ >>> logger.info(train_result)
+ """
+ init(autoreset=True)
+ LOGROOT = "./log/"
+ dir_name = os.path.dirname(LOGROOT)
+ ensure_dir(dir_name)
+ model_name = os.path.join(dir_name, config["model"])
+ ensure_dir(model_name)
+ config_str = "".join([str(key) for key in config.final_config_dict.values()])
+ md5 = hashlib.md5(config_str.encode(encoding="utf-8")).hexdigest()[:6]
+ logfilename = "{}/{}-{}-{}-{}.log".format(
+ config["model"], config["model"], config["dataset"], get_local_time(), md5
+ )
+
+ logfilepath = os.path.join(LOGROOT, logfilename)
+
+ filefmt = "%(asctime)-15s %(levelname)s %(message)s"
+ filedatefmt = "%a %d %b %Y %H:%M:%S"
+ fileformatter = logging.Formatter(filefmt, filedatefmt)
+
+ sfmt = "%(log_color)s%(asctime)-15s %(levelname)s %(message)s"
+ sdatefmt = "%d %b %H:%M"
+ sformatter = colorlog.ColoredFormatter(sfmt, sdatefmt, log_colors=log_colors_config)
+ if config["state"] is None or config["state"].lower() == "info":
+ level = logging.INFO
+ elif config["state"].lower() == "debug":
+ level = logging.DEBUG
+ elif config["state"].lower() == "error":
+ level = logging.ERROR
+ elif config["state"].lower() == "warning":
+ level = logging.WARNING
+ elif config["state"].lower() == "critical":
+ level = logging.CRITICAL
+ else:
+ level = logging.INFO
+
+ fh = logging.FileHandler(logfilepath)
+ fh.setLevel(level)
+ fh.setFormatter(fileformatter)
+ remove_color_filter = RemoveColorFilter()
+ fh.addFilter(remove_color_filter)
+
+ sh = logging.StreamHandler()
+ sh.setLevel(level)
+ sh.setFormatter(sformatter)
+
+ logging.basicConfig(level=level, handlers=[sh, fh])
\ No newline at end of file
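
`set_color` above only wraps the message in ANSI escape sequences; reproduced standalone (so the snippet runs without the toolkit installed), its output can be checked directly:

```python
# Standalone copy of set_color() from logger.py above, for a quick check
# of the ANSI sequences it emits.
def set_color(log, color, highlight=True):
    color_set = ["black", "red", "green", "yellow", "blue", "pink", "cyan", "white"]
    try:
        index = color_set.index(color)
    except ValueError:
        index = len(color_set) - 1  # unknown colors fall back to white
    prefix = "\033[1;3" if highlight else "\033[0;3"
    return prefix + str(index) + "m" + log + "\033[0m"

print(repr(set_color("epoch 3", "green")))  # '\x1b[1;32mepoch 3\x1b[0m'
```

The `RemoveColorFilter` in the same file strips exactly these sequences before messages reach the log file.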
diff --git a/talkingface-toolkit-main/talkingface/utils/utils.py b/talkingface-toolkit-main/talkingface/utils/utils.py
new file mode 100644
index 00000000..a5019491
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/utils.py
@@ -0,0 +1,455 @@
+import datetime
+import importlib
+import importlib.util
+import os
+import random
+import pandas as pd
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from torch.utils.tensorboard import SummaryWriter
+from texttable import Texttable
+
+def get_local_time():
+ r"""Get current time
+
+ Returns:
+ str: current time
+ """
+ cur = datetime.datetime.now()
+ cur = cur.strftime("%b-%d-%Y_%H-%M-%S")
+
+ return cur
+
+
+def ensure_dir(dir_path):
+ r"""Make sure the directory exists, if it does not exist, create it
+
+ Args:
+ dir_path (str): directory path
+
+ """
+ if not os.path.exists(dir_path):
+ os.makedirs(dir_path)
+
+def get_model(model_name):
+ r"""Automatically select model class based on model name
+
+ Args:
+ model_name (str): model name
+
+ Returns:
+ Recommender: model class
+ """
+ model_submodule = [
+ "audio_driven_talkingface",
+ "image_driven_talkingface",
+ "nerf_based_talkingface",
+ "text_to_speech",
+ "voice_conversion"
+
+ ]
+
+ model_file_name = model_name.lower()
+ model_module = None
+ for submodule in model_submodule:
+ module_path = ".".join(["talkingface.model", submodule, model_file_name])
+ if importlib.util.find_spec(module_path, __name__):
+ model_module = importlib.import_module(module_path, __name__)
+ break
+
+ if model_module is None:
+ raise ValueError(
+ "`model_name` [{}] is not the name of an existing model.".format(model_name)
+ )
+ model_class = getattr(model_module, model_name)
+ return model_class
+
+def get_trainer(model_name):
+ r"""Automatically select trainer class based on model name
+
+ Args:
+ model_name (str): model name
+
+ Returns:
+ Trainer: trainer class
+ """
+ try:
+ return getattr(
+ importlib.import_module("talkingface.trainer"), model_name + "Trainer"
+ )
+ except AttributeError:
+ raise AttributeError(
+ "There is no trainer named `{}`".format(model_name)
+ )
+
+def early_stopping(value, best, cur_step, max_step, bigger=False):
+ r"""validation-based early stopping
+
+ Args:
+ value (float): current result
+ best (float): best result
+ cur_step (int): the number of consecutive steps that did not exceed the best result
+ max_step (int): threshold steps for stopping
+ bigger (bool, optional): whether the bigger the better
+
+ Returns:
+ tuple:
+ - float,
+ best result after this step
+ - int,
+ the number of consecutive steps that did not exceed the best result after this step
+ - bool,
+ whether to stop
+ - bool,
+ whether to update
+ """
+ stop_flag = False
+ update_flag = False
+ if bigger:
+ if value >= best:
+ cur_step = 0
+ best = value
+ update_flag = True
+ else:
+ cur_step += 1
+ if cur_step > max_step:
+ stop_flag = True
+ else:
+ if value <= best:
+ cur_step = 0
+ best = value
+ update_flag = True
+ else:
+ cur_step += 1
+ if cur_step > max_step:
+ stop_flag = True
+ return best, cur_step, stop_flag, update_flag
+
+
+def calculate_valid_score(valid_result, valid_metric=None):
+    r"""Calculate the valid score according to the valid result and valid metric
+
+    Args:
+        valid_result (dict): valid result
+        valid_metric (Metric or None, optional): valid metric
+
+    Returns:
+        float: valid score
+    """
+    if valid_metric is None:
+        return valid_result
+    else:
+        return valid_result[valid_metric]
+
+def dict2str(result_dict):
+ r"""convert result dict to str
+
+ Args:
+ result_dict (dict): result dict
+
+ Returns:
+ str: result str
+ """
+
+ return " ".join(
+ [str(metric) + " : " + str(value) for metric, value in result_dict.items()]
+ )
+
+
+def init_seed(seed, reproducibility):
+ r"""init random seed for random functions in numpy, torch, cuda and cudnn
+
+ Args:
+ seed (int): random seed
+ reproducibility (bool): Whether to require reproducibility
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if reproducibility:
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+ else:
+ torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.deterministic = False
+
+def get_tensorboard(logger):
+ r"""Creates a SummaryWriter of Tensorboard that can log PyTorch models and metrics into a directory for
+ visualization within the TensorBoard UI.
+ For the convenience of the user, the naming rule of the SummaryWriter's log_dir is the same as the logger.
+
+ Args:
+ logger: its output filename is used to name the SummaryWriter's log_dir.
+ If the filename is not available, we will name the log_dir according to the current time.
+
+ Returns:
+ SummaryWriter: it will write out events and summaries to the event file.
+ """
+ base_path = "log_tensorboard"
+
+ dir_name = None
+ for handler in logger.handlers:
+ if hasattr(handler, "baseFilename"):
+ dir_name = os.path.basename(getattr(handler, "baseFilename")).split(".")[0]
+ break
+ if dir_name is None:
+ dir_name = "{}-{}".format("model", get_local_time())
+
+ dir_path = os.path.join(base_path, dir_name)
+ writer = SummaryWriter(dir_path)
+ return writer
+
+def get_gpu_usage(device=None):
+ r"""Return the reserved memory and total memory of given device in a string.
+ Args:
+        device: cuda.device. It is the device that the model runs on.
+
+ Returns:
+ str: it contains the info about reserved memory and total memory of given device.
+ """
+
+ reserved = torch.cuda.max_memory_reserved(device) / 1024**3
+ total = torch.cuda.get_device_properties(device).total_memory / 1024**3
+
+ return "{:.2f} G/{:.2f} G".format(reserved, total)
+
+def get_flops(model, dataset, device, logger, transform, verbose=False):
+    r"""Given a model and a dataset fed to the model, compute the total flops
+    of the given model.
+    Args:
+        model: the model to compute flop counts for.
+        dataset: dataset that is passed to `model` to count flops.
+        device: cuda.device. It is the device that the model runs on.
+        logger: logger used to report which counting rule is applied to each module.
+        transform: transform applied to the sampled data before the forward pass.
+        verbose: whether to print information of modules.
+
+    Returns:
+        total_ops: the total number of flops of the model.
+ """
+ # if model.type == ModelType.DECISIONTREE:
+ # return 1
+ # if model.__class__.__name__ == "Pop":
+ # return 1
+
+ import copy
+
+ model = copy.deepcopy(model)
+
+ def count_normalization(m, x, y):
+ x = x[0]
+ flops = torch.DoubleTensor([2 * x.numel()])
+ m.total_ops += flops
+
+ def count_embedding(m, x, y):
+ x = x[0]
+ nelements = x.numel()
+ hiddensize = y.shape[-1]
+ m.total_ops += nelements * hiddensize
+
+ class TracingAdapter(torch.nn.Module):
+ def __init__(self, rec_model):
+ super().__init__()
+ self.model = rec_model
+
+ def forward(self, interaction):
+ return self.model.predict(interaction)
+
+ custom_ops = {
+ torch.nn.Embedding: count_embedding,
+ torch.nn.LayerNorm: count_normalization,
+ }
+ wrapper = TracingAdapter(model)
+ inter = dataset[torch.tensor([1])].to(device)
+ inter = transform(dataset, inter)
+ inputs = (inter,)
+ from thop.profile import register_hooks
+ from thop.vision.basic_hooks import count_parameters
+
+ handler_collection = {}
+ fn_handles = []
+ params_handles = []
+ types_collection = set()
+ if custom_ops is None:
+ custom_ops = {}
+
+ def add_hooks(m: nn.Module):
+ m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64))
+ m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64))
+
+ m_type = type(m)
+
+ fn = None
+ if m_type in custom_ops:
+ fn = custom_ops[m_type]
+ if m_type not in types_collection and verbose:
+ logger.info("Customize rule %s() %s." % (fn.__qualname__, m_type))
+ elif m_type in register_hooks:
+ fn = register_hooks[m_type]
+ if m_type not in types_collection and verbose:
+ logger.info("Register %s() for %s." % (fn.__qualname__, m_type))
+ else:
+ if m_type not in types_collection and verbose:
+ logger.warning(
+ "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params."
+ % m_type
+ )
+
+ if fn is not None:
+ handle_fn = m.register_forward_hook(fn)
+ handle_paras = m.register_forward_hook(count_parameters)
+ handler_collection[m] = (
+ handle_fn,
+ handle_paras,
+ )
+ fn_handles.append(handle_fn)
+ params_handles.append(handle_paras)
+ types_collection.add(m_type)
+
+ prev_training_status = wrapper.training
+
+ wrapper.eval()
+ wrapper.apply(add_hooks)
+
+ with torch.no_grad():
+ wrapper(*inputs)
+
+ def dfs_count(module: nn.Module, prefix="\t"):
+ total_ops, total_params = module.total_ops.item(), 0
+ ret_dict = {}
+ for n, m in module.named_children():
+ next_dict = {}
+ if m in handler_collection and not isinstance(
+ m, (nn.Sequential, nn.ModuleList)
+ ):
+ m_ops, m_params = m.total_ops.item(), m.total_params.item()
+ else:
+ m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t")
+ ret_dict[n] = (m_ops, m_params, next_dict)
+ total_ops += m_ops
+ total_params += m_params
+
+ return total_ops, total_params, ret_dict
+
+ total_ops, total_params, ret_dict = dfs_count(wrapper)
+
+ # reset wrapper to original status
+ wrapper.train(prev_training_status)
+ for m, (op_handler, params_handler) in handler_collection.items():
+ m._buffers.pop("total_ops")
+ m._buffers.pop("total_params")
+ for i in range(len(fn_handles)):
+ fn_handles[i].remove()
+ params_handles[i].remove()
+
+ return total_ops
+
+def list_to_latex(convert_list, bigger_flag=True, subset_columns=[]):
+ result = {}
+ for d in convert_list:
+ for key, value in d.items():
+ if key in result:
+ result[key].append(value)
+ else:
+ result[key] = [value]
+
+ df = pd.DataFrame.from_dict(result, orient="index").T
+
+ if len(subset_columns) == 0:
+ tex = df.to_latex(index=False)
+ return df, tex
+
+ def bold_func(x, bigger_flag):
+ if bigger_flag:
+ return np.where(x == np.max(x.to_numpy()), "font-weight:bold", None)
+ else:
+ return np.where(x == np.min(x.to_numpy()), "font-weight:bold", None)
+
+ style = df.style
+ style.apply(bold_func, bigger_flag=bigger_flag, subset=subset_columns)
+ style.format(precision=4)
+
+ num_column = len(df.columns)
+ column_format = "c" * num_column
+ tex = style.hide(axis="index").to_latex(
+ caption="Result Table",
+ label="Result Table",
+ convert_css=True,
+ hrules=True,
+ column_format=column_format,
+ )
+
+ return df, tex
+
+def get_environment(config):
+ gpu_usage = (
+ get_gpu_usage(config["device"])
+ if torch.cuda.is_available() and config["use_gpu"]
+ else "0.0 / 0.0"
+ )
+
+ import psutil
+
+ memory_used = psutil.Process(os.getpid()).memory_info().rss / 1024**3
+ memory_total = psutil.virtual_memory()[0] / 1024**3
+ memory_usage = "{:.2f} G/{:.2f} G".format(memory_used, memory_total)
+ cpu_usage = "{:.2f} %".format(psutil.cpu_percent(interval=1))
+
+ table = Texttable()
+ table.set_cols_align(["l", "c"])
+ table.set_cols_valign(["m", "m"])
+ table.add_rows(
+ [
+ ["Environment", "Usage"],
+ ["CPU", cpu_usage],
+ ["GPU", gpu_usage],
+ ["Memory", memory_usage],
+ ]
+ )
+
+ return table
+
+def get_preprocess(dataset_name):
+ r"""Automatically select dataset preprocess class based on dataset name
+ """
+ model_file_name = dataset_name.lower()
+ module_path = "talkingface.utils.data_process"
+ model_module = importlib.import_module(module_path, __name__)
+ try:
+ preprocess_class = getattr(model_module, dataset_name+'Preprocess')
+    except AttributeError:
+ raise ValueError(
+ "`dataset_name` [{}] is not the name of an existing dataset.".format(dataset_name)
+ )
+ return preprocess_class
+
+def create_dataset(config):
+ r"""Automatically select dataset class based on dataset name
+ """
+    model_name = config['model']
+    dataset_file_name = model_name.lower() + '_dataset'
+    module_path = ".".join(["talkingface.data.dataset", dataset_file_name])
+    dataset_module = None
+    if importlib.util.find_spec(module_path, __name__):
+        dataset_module = importlib.import_module(module_path, __name__)
+    if dataset_module is None:
+        raise ValueError(
+            "`dataset_file_name` [{}] is not the name of an existing dataset.".format(dataset_file_name)
+        )
+    dataset_class = getattr(dataset_module, model_name + 'Dataset')
+
+ return dataset_class(config, config['train_filelist']), dataset_class(config, config['val_filelist'])
+
+
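
The `early_stopping` bookkeeping above can be exercised on a toy loss curve. The function is reproduced here in condensed but behaviorally equivalent form so the snippet runs standalone (`bigger=False` means lower values are better):

```python
# Condensed, behaviorally equivalent copy of early_stopping() from utils.py.
def early_stopping(value, best, cur_step, max_step, bigger=False):
    stop_flag = False
    update_flag = False
    better = value >= best if bigger else value <= best
    if better:
        cur_step, best, update_flag = 0, value, True
    else:
        cur_step += 1
        if cur_step > max_step:
            stop_flag = True
    return best, cur_step, stop_flag, update_flag

losses = [1.0, 0.8, 0.9, 0.85, 0.95, 0.9]  # improvement stalls after epoch 1
best, cur_step = float("inf"), 0
stopped_at = None
for epoch, loss in enumerate(losses):
    best, cur_step, stop, _ = early_stopping(loss, best, cur_step, max_step=2)
    if stop:
        stopped_at = epoch
        break

print(best, stopped_at)  # 0.8 4: stops after three epochs without improvement
```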
diff --git a/talkingface-toolkit-main/talkingface/utils/wandblogger.py b/talkingface-toolkit-main/talkingface/utils/wandblogger.py
new file mode 100644
index 00000000..8d958474
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/utils/wandblogger.py
@@ -0,0 +1,57 @@
+class WandbLogger(object):
+ """WandbLogger to log metrics to Weights and Biases."""
+
+ def __init__(self, config):
+ """
+ Args:
+            config (Config): configuration object of the toolkit, providing `log_wandb` and `wandb_project`.
+ """
+ self.config = config
+ self.log_wandb = config.log_wandb
+ self.setup()
+
+ def setup(self):
+ if self.log_wandb:
+ try:
+ import wandb
+
+ self._wandb = wandb
+ except ImportError:
+ raise ImportError(
+ "To use the Weights and Biases Logger please install wandb."
+ "Run `pip install wandb` to install it."
+ )
+
+ # Initialize a W&B run
+ if self._wandb.run is None:
+ self._wandb.init(project=self.config.wandb_project, config=self.config)
+
+ self._set_steps()
+
+ def log_metrics(self, metrics, head="train", commit=True):
+ if self.log_wandb:
+ if head:
+ metrics = self._add_head_to_metrics(metrics, head)
+ self._wandb.log(metrics, commit=commit)
+ else:
+ self._wandb.log(metrics, commit=commit)
+
+ def log_eval_metrics(self, metrics, head="eval"):
+ if self.log_wandb:
+ metrics = self._add_head_to_metrics(metrics, head)
+ for k, v in metrics.items():
+ self._wandb.run.summary[k] = v
+
+ def _set_steps(self):
+ self._wandb.define_metric("train/*", step_metric="train_step")
+ self._wandb.define_metric("valid/*", step_metric="valid_step")
+
+ def _add_head_to_metrics(self, metrics, head):
+ head_metrics = dict()
+ for k, v in metrics.items():
+ if "_step" in k:
+ head_metrics[k] = v
+ else:
+ head_metrics[f"{head}/{k}"] = v
+
+ return head_metrics
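
The naming scheme in `_add_head_to_metrics` above keeps step counters (any key containing `_step`) un-prefixed, so W&B can use them as step metrics via `define_metric`, while everything else is namespaced under the head. A standalone sketch of that logic:

```python
# Standalone sketch of the _add_head_to_metrics() naming scheme above.
def add_head_to_metrics(metrics, head):
    out = {}
    for k, v in metrics.items():
        if "_step" in k:
            out[k] = v              # step counters stay global
        else:
            out[f"{head}/{k}"] = v  # everything else gets a head namespace
    return out

print(add_head_to_metrics({"loss": 0.31, "train_step": 10}, "train"))
# {'train/loss': 0.31, 'train_step': 10}
```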