From d04b679f554859daff1c3d392e79a5062a747fd4 Mon Sep 17 00:00:00 2001 From: Aquariuslyh <144698988+Aquariuslyh@users.noreply.github.com> Date: Tue, 30 Jan 2024 23:04:41 +0800 Subject: [PATCH] Add files via upload --- talkingface-toolkit-main/README.md | 216 ++++ .../checkpoints/README.md | 16 + talkingface-toolkit-main/dataset/README.md | 41 + talkingface-toolkit-main/environment.yml | 138 ++ talkingface-toolkit-main/requirements.txt | 114 ++ talkingface-toolkit-main/run_talkingface.py | 24 + .../talkingface/config/__init__.py | 1 + .../talkingface/config/configurator.py | 335 +++++ .../talkingface/data/__init__.py | 0 .../data/dataprocess/ADNeRFmaster_process.py | 0 .../talkingface/data/dataprocess/__init__.py | 1 + .../data/dataprocess/wav2lip_process.py | 253 ++++ .../data/dataset/AD_NeRF_master_dataset.py | 282 ++++ .../talkingface/data/dataset/__init__.py | 3 + .../talkingface/data/dataset/dataset.py | 25 + .../data/dataset/wav2lip_dataset.py | 158 +++ .../talkingface/evaluator/__init__.py | 5 + .../talkingface/evaluator/base_metric.py | 98 ++ .../talkingface/evaluator/evaluator.py | 22 + .../talkingface/evaluator/metric_models.py | 105 ++ .../talkingface/evaluator/metrics.py | 243 ++++ .../talkingface/evaluator/register.py | 46 + .../talkingface/model/__init__.py | 0 .../talkingface/model/abstract_speech.py | 75 ++ .../talkingface/model/abstract_talkingface.py | 75 ++ .../audio_driven_talkingface/__init__.py | 1 + .../image_driven_talkingface/__init__.py | 0 .../talkingface/model/layers.py | 44 + .../ADNeRFmaster/__init__.py | 142 +++ .../ADNeRFmaster/deepspeech_features.py | 275 ++++ .../ADNeRFmaster/deepspeech_store.py | 172 +++ .../ADNeRFmaster/extract_ds_features.py | 129 ++ .../ADNeRFmaster/extract_wav.py | 87 ++ .../ADNeRFmaster/fea_win.py | 11 + .../ADNeRFmaster/load_audface.py | 114 ++ .../ADNeRFmaster/run_nerf.py | 1136 +++++++++++++++++ .../ADNeRFmaster/run_nerf_helpers.py | 454 +++++++ .../nerf_based_talkingface/AD_NeRF_master.py | 114 ++ 
.../model/nerf_based_talkingface/__init__.py | 1 + .../model/nerf_based_talkingface/wav2lip.py | 358 ++++++ .../talkingface/properties/dataset/lrs2.yaml | 10 + .../properties/model/AD_NeRF_master.yaml | 37 + .../properties/model/AD_NeRF_master/May.yaml | 32 + .../model/AD_NeRF_master/McStay.yaml | 32 + .../model/AD_NeRF_master/Nadella.yaml | 32 + .../model/AD_NeRF_master/Obama1.yaml | 32 + .../model/AD_NeRF_master/Obama2.yaml | 32 + .../talkingface/properties/model/Wav2Lip.yaml | 55 + .../talkingface/properties/overall.yaml | 31 + .../talkingface/quick_start/__init__.py | 5 + .../talkingface/quick_start/quick_start.py | 105 ++ .../trainer/AD_NeRF_masterTrainer.py | 965 ++++++++++++++ .../talkingface/trainer/__init__.py | 2 + .../talkingface/trainer/trainer.py | 557 ++++++++ .../talkingface/utils/__init__.py | 43 + .../talkingface/utils/argument_list.py | 48 + .../talkingface/utils/data_process.py | 95 ++ .../talkingface/utils/enum_type.py | 13 + .../utils/face_detection/README.md | 1 + .../utils/face_detection/__init__.py | 7 + .../talkingface/utils/face_detection/api.py | 79 ++ .../face_detection/detection/__init__.py | 1 + .../utils/face_detection/detection/core.py | 130 ++ .../face_detection/detection/sfd/__init__.py | 1 + .../face_detection/detection/sfd/bbox.py | 129 ++ .../face_detection/detection/sfd/detect.py | 112 ++ .../face_detection/detection/sfd/net_s3fd.py | 129 ++ .../detection/sfd/sfd_detector.py | 59 + .../utils/face_detection/models.py | 261 ++++ .../talkingface/utils/face_detection/utils.py | 313 +++++ .../talkingface/utils/logger.py | 95 ++ .../talkingface/utils/utils.py | 455 +++++++ .../talkingface/utils/wandblogger.py | 57 + 73 files changed, 9269 insertions(+) create mode 100644 talkingface-toolkit-main/README.md create mode 100644 talkingface-toolkit-main/checkpoints/README.md create mode 100644 talkingface-toolkit-main/dataset/README.md create mode 100644 talkingface-toolkit-main/environment.yml create mode 100644 
talkingface-toolkit-main/requirements.txt create mode 100644 talkingface-toolkit-main/run_talkingface.py create mode 100644 talkingface-toolkit-main/talkingface/config/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/config/configurator.py create mode 100644 talkingface-toolkit-main/talkingface/data/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/data/dataprocess/ADNeRFmaster_process.py create mode 100644 talkingface-toolkit-main/talkingface/data/dataprocess/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/data/dataprocess/wav2lip_process.py create mode 100644 talkingface-toolkit-main/talkingface/data/dataset/AD_NeRF_master_dataset.py create mode 100644 talkingface-toolkit-main/talkingface/data/dataset/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/data/dataset/dataset.py create mode 100644 talkingface-toolkit-main/talkingface/data/dataset/wav2lip_dataset.py create mode 100644 talkingface-toolkit-main/talkingface/evaluator/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/evaluator/base_metric.py create mode 100644 talkingface-toolkit-main/talkingface/evaluator/evaluator.py create mode 100644 talkingface-toolkit-main/talkingface/evaluator/metric_models.py create mode 100644 talkingface-toolkit-main/talkingface/evaluator/metrics.py create mode 100644 talkingface-toolkit-main/talkingface/evaluator/register.py create mode 100644 talkingface-toolkit-main/talkingface/model/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/model/abstract_speech.py create mode 100644 talkingface-toolkit-main/talkingface/model/abstract_talkingface.py create mode 100644 talkingface-toolkit-main/talkingface/model/audio_driven_talkingface/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/model/image_driven_talkingface/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/model/layers.py create mode 100644 
talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_features.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_store.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_ds_features.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_wav.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/fea_win.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/load_audface.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf_helpers.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/AD_NeRF_master.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/wav2lip.py create mode 100644 talkingface-toolkit-main/talkingface/properties/dataset/lrs2.yaml create mode 100644 talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master.yaml create mode 100644 talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/May.yaml create mode 100644 talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/McStay.yaml create mode 100644 talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Nadella.yaml create mode 100644 talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama1.yaml create mode 100644 talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama2.yaml create mode 
100644 talkingface-toolkit-main/talkingface/properties/model/Wav2Lip.yaml create mode 100644 talkingface-toolkit-main/talkingface/properties/overall.yaml create mode 100644 talkingface-toolkit-main/talkingface/quick_start/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/quick_start/quick_start.py create mode 100644 talkingface-toolkit-main/talkingface/trainer/AD_NeRF_masterTrainer.py create mode 100644 talkingface-toolkit-main/talkingface/trainer/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/trainer/trainer.py create mode 100644 talkingface-toolkit-main/talkingface/utils/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/utils/argument_list.py create mode 100644 talkingface-toolkit-main/talkingface/utils/data_process.py create mode 100644 talkingface-toolkit-main/talkingface/utils/enum_type.py create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/README.md create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/api.py create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/detection/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/detection/core.py create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/__init__.py create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/bbox.py create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/detect.py create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/net_s3fd.py create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/sfd_detector.py create mode 100644 talkingface-toolkit-main/talkingface/utils/face_detection/models.py create mode 100644 
talkingface-toolkit-main/talkingface/utils/face_detection/utils.py create mode 100644 talkingface-toolkit-main/talkingface/utils/logger.py create mode 100644 talkingface-toolkit-main/talkingface/utils/utils.py create mode 100644 talkingface-toolkit-main/talkingface/utils/wandblogger.py diff --git a/talkingface-toolkit-main/README.md b/talkingface-toolkit-main/README.md new file mode 100644 index 00000000..b6f46dbb --- /dev/null +++ b/talkingface-toolkit-main/README.md @@ -0,0 +1,216 @@ +python meta_portrait_base_inference.py --save_dir ./saved --config ./talkingface/properties/model/meta_portrait_base_config/meta_portrait_256_eval.yaml --ckpt ./saved/ckpt_base.pth.tar + +python meta_portrait_base_main.py --config ./talkingface/properties/model/meta_portrait_base_config/meta_portrait_256_pretrain_warp.yaml --fp16 --stage Warp --task Pretrain + + + +# talkingface-toolkit +## 框架整体介绍 +### checkpoints +主要保存的是训练和评估模型所需要的额外的预训练模型,在对应文件夹的[README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/checkpoints/README.md)有更详细的介绍 + +### dataset +存放数据集以及数据集预处理之后的数据,详细内容见dataset里的[README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/dataset/README.md) + +### saved +存放训练过程中保存的模型checkpoint, 训练过程中保存模型时自动创建 + +### talkingface +主要功能模块,包括所有核心代码 + +#### config +根据模型和数据集名称自动生成所有模型、数据集、训练、评估等相关的配置信息 +``` +config/ + +├── configurator.py + +``` +#### data +- dataprocess:模型特有的数据处理代码,(可以是对方仓库自己实现的音频特征提取、推理时的数据处理)。如果实现的模型有这个需求,就要建立一个对应的文件 +- dataset:每个模型都要重载`torch.utils.data.Dataset` 用于加载数据。每个模型都要有一个`model_name+'_dataset.py'`文件. 
`__getitem__()`方法的返回值应处理成字典类型的数据。 (核心部分) +``` +data/ + +├── dataprocess + +| ├── wav2lip_process.py + +| ├── xxxx_process.py + +├── dataset + +| ├── wav2lip_dataset.py + +| ├── xxx_dataset.py +``` + +#### evaluate +主要涉及模型评估的代码 +LSE metric 需要的数据是生成的视频列表 +SSIM metric 需要的数据是生成的视频和真实的视频列表 + +#### model +实现的模型的网络和对应的方法 (核心部分) + +主要分三类: +- audio-driven (音频驱动) +- image-driven (图像驱动) +- nerf-based (基于神经辐射场的方法) + +``` +model/ + +├── audio_driven_talkingface + +| ├── wav2lip.py + +├── image_driven_talkingface + +| ├── xxxx.py + +├── nerf_based_talkingface + +| ├── xxxx.py + +├── abstract_talkingface.py + +``` + +#### properties +保存默认配置文件,包括: +- 数据集配置文件 +- 模型配置文件 +- 通用配置文件 + +需要根据对应模型和数据集增加对应的配置文件,通用配置文件`overall.yaml`一般不做修改 +``` +properties/ + +├── dataset + +| ├── xxx.yaml + +├── model + +| ├── xxx.yaml + +├── overall.yaml + +``` + +#### quick_start +通用的启动文件,根据传入参数自动配置数据集和模型,然后训练和评估(一般不需要修改) +``` +quick_start/ + +├── quick_start.py + +``` + +#### trainer +训练、评估函数的主类。在trainer中,如果可以使用基类`Trainer`实现所有功能,则不需要写一个新的。如果模型训练有一些特有部分,则需要重载`Trainer`。需要重载部分可能主要集中于: `_train_epoch()`, `_valid_epoch()`。 重载的`Trainer`应该命名为:`{model_name}Trainer` +``` +trainer/ + +├── trainer.py + +``` + +#### utils +公用的工具类,包括`s3fd`人脸检测,视频抽帧、视频抽音频方法。还包括根据参数配置找对应的模型类、数据类等方法。 +一般不需要修改,但可以适当添加一些必须的且相对普遍的数据处理文件。 + +## 使用方法 +### 环境要求 +- `python=3.8` +- `torch==1.13.1+cu116`(gpu版,若设备不支持cuda可以使用cpu版) +- `numpy==1.20.3` +- `librosa==0.10.1` + +尽量保证上面几个包的版本一致 + +提供了两种配置其他环境的方法: +``` +pip install -r requirements.txt + +or + +conda env create -f environment.yml +``` + +建议使用conda虚拟环境!!! 
+ +### 训练和评估 + +```bash +python run_talkingface.py --model=xxxx --dataset=xxxx (--other_parameters=xxxxxx) +``` + +### 权重文件 + +- LSE评估需要的权重: syncnet_v2.model [百度网盘下载](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc) +- wav2lip需要的lip expert 权重:lipsync_expert.pth [百度网盘下载](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc) + +## 可选论文: +### Audio_driven talkingface +| 模型简称 | 论文 | 代码仓库 | +|:--------:|:--------:|:--------:| +| MakeItTalk | [paper](https://arxiv.org/abs/2004.12992) | [code](https://github.com/yzhou359/MakeItTalk) | +| MEAD | [paper](https://wywu.github.io/projects/MEAD/support/MEAD.pdf) | [code](https://github.com/uniBruce/Mead) | +| RhythmicHead | [paper](https://arxiv.org/pdf/2007.08547v1.pdf) | [code](https://github.com/lelechen63/Talking-head-Generation-with-Rhythmic-Head-Motion) | +| PC-AVS | [paper](https://arxiv.org/abs/2104.11116) | [code](https://github.com/Hangz-nju-cuhk/Talking-Face_PC-AVS) | +| EVP | [paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Ji_Audio-Driven_Emotional_Video_Portraits_CVPR_2021_paper.pdf) | [code](https://github.com/jixinya/EVP) | +| LSP | [paper](https://arxiv.org/abs/2109.10595) | [code](https://github.com/YuanxunLu/LiveSpeechPortraits) | +| EAMM | [paper](https://arxiv.org/pdf/2205.15278.pdf) | [code](https://github.com/jixinya/EAMM/) | +| DiffTalk | [paper](https://arxiv.org/abs/2301.03786) | [code](https://github.com/sstzal/DiffTalk) | +| TalkLip | [paper](https://arxiv.org/pdf/2303.17480.pdf) | [code](https://github.com/Sxjdwang/TalkLip) | +| EmoGen | [paper](https://arxiv.org/pdf/2303.11548.pdf) | [code](https://github.com/sahilg06/EmoGen) | +| SadTalker | [paper](https://arxiv.org/abs/2211.12194) | [code](https://github.com/OpenTalker/SadTalker) | +| HyperLips | [paper](https://arxiv.org/abs/2310.05720) | [code](https://github.com/semchan/HyperLips) | +| PHADTF | [paper](http://arxiv.org/abs/2002.10137) | [code](https://github.com/yiranran/Audio-driven-TalkingFace-HeadPose) | +| 
VideoReTalking | [paper](https://arxiv.org/abs/2211.14758) | [code](https://github.com/OpenTalker/video-retalking#videoretalking--audio-based-lip-synchronization-for-talking-head-video-editing-in-the-wild-) +| | + + + +### Image_driven talkingface +| 模型简称 | 论文 | 代码仓库 | +|:--------:|:--------:|:--------:| +| PIRenderer | [paper](https://arxiv.org/pdf/2109.08379.pdf) | [code](https://github.com/RenYurui/PIRender) | +| StyleHEAT | [paper](https://arxiv.org/pdf/2203.04036.pdf) | [code](https://github.com/OpenTalker/StyleHEAT) | +| MetaPortrait | [paper](https://arxiv.org/abs/2212.08062) | [code](https://github.com/Meta-Portrait/MetaPortrait) | +| | +### Nerf-based talkingface +| 模型简称 | 论文 | 代码仓库 | +|:--------:|:--------:|:--------:| +| AD-NeRF | [paper](https://arxiv.org/abs/2103.11078) | [code](https://github.com/YudongGuo/AD-NeRF) | +| GeneFace | [paper](https://arxiv.org/abs/2301.13430) | [code](https://github.com/yerfor/GeneFace) | +| DFRF | [paper](https://arxiv.org/abs/2207.11770) | [code](https://github.com/sstzal/DFRF) | +| | +### text_to_speech +| 模型简称 | 论文 | 代码仓库 | +|:--------:|:--------:|:--------:| +| VITS | [paper](https://arxiv.org/abs/2106.06103) | [code](https://github.com/jaywalnut310/vits) | +| Glow TTS | [paper](https://arxiv.org/abs/2005.11129) | [code](https://github.com/jaywalnut310/glow-tts) | +| FastSpeech2 | [paper](https://arxiv.org/abs/2006.04558v1) | [code](https://github.com/ming024/FastSpeech2) | +| StyleTTS2 | [paper](https://arxiv.org/abs/2306.07691) | [code](https://github.com/yl4579/StyleTTS2) | +| Grad-TTS | [paper](https://arxiv.org/abs/2105.06337) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS) | +| FastSpeech | [paper](https://arxiv.org/abs/1905.09263) | [code](https://github.com/xcmyz/FastSpeech) | +| | +### voice_conversion +| 模型简称 | 论文 | 代码仓库 | +|:--------:|:--------:|:--------:| +| StarGAN-VC | [paper](http://www.kecl.ntt.co.jp/people/kameoka.hirokazu/Demos/stargan-vc2/index.html) | 
[code](https://github.com/kamepong/StarGAN-VC) | +| Emo-StarGAN | [paper](https://www.researchgate.net/publication/373161292_Emo-StarGAN_A_Semi-Supervised_Any-to-Many_Non-Parallel_Emotion-Preserving_Voice_Conversion) | [code](https://github.com/suhitaghosh10/emo-stargan) | +| adaptive-VC | [paper](https://arxiv.org/abs/1904.05742) | [code](https://github.com/jjery2243542/adaptive_voice_conversion) | +| DiffVC | [paper](https://arxiv.org/abs/2109.13821) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC) | +| Assem-VC | [paper](https://arxiv.org/abs/2104.00931) | [code](https://github.com/maum-ai/assem-vc) | +| | + +## 作业要求 +- 确保可以仅在命令行输入模型和数据集名称就可以训练、验证。(部分仓库没有提供训练代码的,可以不训练) +- 每个组都要提交一个README文件,写明完成的功能、最终实现的训练、验证截图、所使用的依赖、成员分工等。 + + + diff --git a/talkingface-toolkit-main/checkpoints/README.md b/talkingface-toolkit-main/checkpoints/README.md new file mode 100644 index 00000000..0a1432d6 --- /dev/null +++ b/talkingface-toolkit-main/checkpoints/README.md @@ -0,0 +1,16 @@ +这个文件夹中保存的是,模型训练或验证过程中用到的一些额外的预训练权重如: +- wav2lip中用到的syncnet权重 +- 计算合成视频lip-audio同步LSE用到的syncnet-v2权重 +- ....... + +目录结构为: +``` +checkpoints/ + +├── LSE +| ├── syncnet_v2.model () + +├── Wav2Lip +| ├── lipsync_expert.pth () + +``` diff --git a/talkingface-toolkit-main/dataset/README.md b/talkingface-toolkit-main/dataset/README.md new file mode 100644 index 00000000..8a2201ad --- /dev/null +++ b/talkingface-toolkit-main/dataset/README.md @@ -0,0 +1,41 @@ +这个文件夹中保存的是数据集如: +- lrw +- lrs2 +- mead +- ....... 
+ +数据集处理的一般格式为: + +``` +dataset/ + +├── lrs2 + +| ├── data (存放数据集的原始数据) + +| ├── filelist (保存的是数据集划分) + +| │ ├── train.txt + +| │ ├── val.txt + +| │ ├── test.txt + +| ├── preprocessed_data (具体路径内容可以参考talkingface.utils.data_preprocess文件中处理lrs2时候的路径,主要存储的是视频抽帧后的图像文件和音频文件) +``` + +preprocessed_data的数据路径一般表示为: +``` +preprocessed_root (lrs2_preprocessed)/ + +├── list of folders + +| ├── Folders with five-digit numbered video IDs + +| │ ├── *.jpg + +| │ ├── audio.wav + +``` + +数据集存储尽量按照这个格式来,数据集的划分也尽量按照train.txt val.txt和test.txt文件来 diff --git a/talkingface-toolkit-main/environment.yml b/talkingface-toolkit-main/environment.yml new file mode 100644 index 00000000..09a8595b --- /dev/null +++ b/talkingface-toolkit-main/environment.yml @@ -0,0 +1,138 @@ +name: torch38 +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - ca-certificates=2023.08.22=h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - ncurses=6.4=h6a678d5_0 + - openssl=3.0.12=h7f8727e_0 + - pip=23.3=py38h06a4308_0 + - python=3.8.18=h955ad1f_0 + - readline=8.2=h5eee18b_0 + - setuptools=68.0.0=py38h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.41.2=py38h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - absl-py==2.0.0 + - addict==2.4.0 + - aiosignal==1.3.1 + - appdirs==1.4.4 + - attrs==23.1.0 + - audioread==3.0.1 + - basicsr==1.3.4.7 + - cachetools==5.3.2 + - certifi==2020.12.5 + - cffi==1.16.0 + - charset-normalizer==3.3.2 + - click==8.1.7 + - cloudpickle==3.0.0 + - colorama==0.4.6 + - colorlog==6.7.0 + - contourpy==1.1.1 + - cycler==0.12.1 + - decorator==5.1.1 + - dlib==19.22.1 + - docker-pycreds==0.4.0 + - face-alignment==1.3.5 + - ffmpeg==1.4 + - filelock==3.13.1 + - fonttools==4.44.0 + - frozenlist==1.4.0 + - future==0.18.3 + - gitdb==4.0.11 + - gitpython==3.1.40 + - glob2==0.7 + - 
google-auth==2.23.4 + - google-auth-oauthlib==0.4.6 + - grpcio==1.59.2 + - hyperopt==0.2.5 + - idna==3.4 + - imageio==2.9.0 + - imageio-ffmpeg==0.4.5 + - importlib-metadata==6.8.0 + - importlib-resources==6.1.0 + - joblib==1.3.2 + - jsonschema==4.19.2 + - jsonschema-specifications==2023.7.1 + - kiwisolver==1.4.5 + - kornia==0.5.5 + - lazy-loader==0.3 + - librosa==0.10.1 + - llvmlite==0.37.0 + - lmdb==1.2.1 + - lws==1.2.7 + - markdown==3.5.1 + - markupsafe==2.1.3 + - matplotlib==3.6.3 + - msgpack==1.0.7 + - networkx==3.1 + - numba==0.54.1 + - numpy==1.20.3 + - oauthlib==3.2.2 + - opencv-python==3.4.9.33 + - packaging==23.2 + - pandas==1.3.4 + - pathtools==0.1.2 + - pillow==6.2.1 + - pkgutil-resolve-name==1.3.10 + - platformdirs==3.11.0 + - plotly==5.18.0 + - pooch==1.8.0 + - protobuf==4.25.0 + - psutil==5.9.6 + - pyasn1==0.5.0 + - pyasn1-modules==0.3.0 + - pycparser==2.21 + - pyparsing==3.1.1 + - python-dateutil==2.8.2 + - python-speech-features==0.6 + - pytorch-fid==0.3.0 + - pytz==2023.3.post1 + - pywavelets==1.4.1 + - pyyaml==5.3.1 + - ray==2.6.3 + - referencing==0.30.2 + - requests==2.31.0 + - requests-oauthlib==1.3.1 + - rpds-py==0.12.0 + - rsa==4.9 + - scikit-image==0.16.2 + - scikit-learn==1.3.2 + - scipy==1.5.0 + - sentry-sdk==1.34.0 + - setproctitle==1.3.3 + - six==1.16.0 + - smmap==5.0.1 + - soundfile==0.12.1 + - soxr==0.3.7 + - tabulate==0.9.0 + - tb-nightly==2.12.0a20230126 + - tenacity==8.2.3 + - tensorboard==2.7.0 + - tensorboard-data-server==0.6.1 + - tensorboard-plugin-wit==1.8.1 + - texttable==1.7.0 + - thop==0.1.1-2209072238 + - threadpoolctl==3.2.0 + - tomli==2.0.1 + - torch==1.13.1+cu116 + - torchaudio==0.13.1+cu116 + - torchvision==0.14.1+cu116 + - tqdm==4.66.1 + - trimesh==3.9.20 + - typing-extensions==4.8.0 + - tzdata==2023.3 + - urllib3==2.0.7 + - wandb==0.15.12 + - werkzeug==3.0.1 + - yapf==0.40.2 + - zipp==3.17.0 diff --git a/talkingface-toolkit-main/requirements.txt b/talkingface-toolkit-main/requirements.txt new file mode 100644 index 
00000000..1605c1fe --- /dev/null +++ b/talkingface-toolkit-main/requirements.txt @@ -0,0 +1,114 @@ +absl-py==2.0.0 +addict==2.4.0 +aiosignal==1.3.1 +appdirs==1.4.4 +attrs==23.1.0 +audioread==3.0.1 +basicsr==1.3.4.7 +cachetools==5.3.2 +certifi==2020.12.5 +cffi==1.16.0 +charset-normalizer==3.3.2 +click==8.1.7 +cloudpickle==3.0.0 +colorama==0.4.6 +colorlog==6.7.0 +contourpy==1.1.1 +cycler==0.12.1 +decorator==5.1.1 +dlib==19.22.1 +docker-pycreds==0.4.0 +face-alignment==1.3.5 +ffmpeg==1.4 +filelock==3.13.1 +fonttools==4.44.0 +frozenlist==1.4.0 +future==0.18.3 +gitdb==4.0.11 +GitPython==3.1.40 +glob2==0.7 +google-auth==2.23.4 +google-auth-oauthlib==0.4.6 +grpcio==1.59.2 +hyperopt==0.2.5 +idna==3.4 +imageio==2.9.0 +imageio-ffmpeg==0.4.5 +importlib-metadata==6.8.0 +importlib-resources==6.1.0 +joblib==1.3.2 +jsonschema==4.19.2 +jsonschema-specifications==2023.7.1 +kiwisolver==1.4.5 +kornia==0.5.5 +lazy_loader==0.3 +librosa==0.10.1 +llvmlite==0.37.0 +lmdb==1.2.1 +lws==1.2.7 +Markdown==3.5.1 +MarkupSafe==2.1.3 +matplotlib==3.6.3 +msgpack==1.0.7 +networkx==3.1 +numba==0.54.1 +numpy==1.20.3 +oauthlib==3.2.2 +opencv-python==3.4.9.33 +packaging==23.2 +pandas==1.3.4 +pathtools==0.1.2 +Pillow==6.2.1 +pkgutil_resolve_name==1.3.10 +platformdirs==3.11.0 +plotly==5.18.0 +pooch==1.8.0 +protobuf==4.25.0 +psutil==5.9.6 +pyasn1==0.5.0 +pyasn1-modules==0.3.0 +pycparser==2.21 +pyparsing==3.1.1 +python-dateutil==2.8.2 +python-speech-features==0.6 +pytorch-fid==0.3.0 +pytz==2023.3.post1 +PyWavelets==1.4.1 +PyYAML==5.3.1 +ray==2.6.3 +referencing==0.30.2 +requests==2.31.0 +requests-oauthlib==1.3.1 +rpds-py==0.12.0 +rsa==4.9 +scikit-image==0.16.2 +scikit-learn==1.3.2 +scipy==1.5.0 +sentry-sdk==1.34.0 +setproctitle==1.3.3 +six==1.16.0 +smmap==5.0.1 +soundfile==0.12.1 +soxr==0.3.7 +tabulate==0.9.0 +tb-nightly==2.12.0a20230126 +tenacity==8.2.3 +tensorboard==2.7.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +texttable==1.7.0 +thop==0.1.1.post2209072238 +threadpoolctl==3.2.0 
+tomli==2.0.1 +torch==1.13.1+cu116 +torchaudio==0.13.1+cu116 +torchvision==0.14.1+cu116 +tqdm==4.66.1 +trimesh==3.9.20 +typing_extensions==4.8.0 +tzdata==2023.3 +urllib3==2.0.7 +wandb==0.15.12 +Werkzeug==3.0.1 +yapf==0.40.2 +zipp==3.17.0 diff --git a/talkingface-toolkit-main/run_talkingface.py b/talkingface-toolkit-main/run_talkingface.py new file mode 100644 index 00000000..3989d566 --- /dev/null +++ b/talkingface-toolkit-main/run_talkingface.py @@ -0,0 +1,24 @@ +import argparse +from talkingface.quick_start import run + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model", "-m", type=str, default="BPR", help="name of models") + parser.add_argument( + "--dataset", "-d", type=str, default=None, help="name of datasets" + ) + parser.add_argument("--evaluate_model_file", type=str, default=None, help="The model file you want to evaluate") + parser.add_argument("--config_files", type=str, default=None, help="config files") + + + args, _ = parser.parse_known_args() + + config_file_list = ( + args.config_files.strip().split(" ") if args.config_files else None + ) + run( + args.model, + args.dataset, + config_file_list=config_file_list, + evaluate_model_file=args.evaluate_model_file + ) \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/config/__init__.py b/talkingface-toolkit-main/talkingface/config/__init__.py new file mode 100644 index 00000000..25fcd4c9 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/config/__init__.py @@ -0,0 +1 @@ +from talkingface.config.configurator import Config \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/config/configurator.py b/talkingface-toolkit-main/talkingface/config/configurator.py new file mode 100644 index 00000000..3fad0f4c --- /dev/null +++ b/talkingface-toolkit-main/talkingface/config/configurator.py @@ -0,0 +1,335 @@ +import re +import os +import sys +import yaml +from logging import getLogger + +from talkingface.utils 
import( + get_model, + # Enum, + # ModelType, + # InputType, + general_arguments, + training_arguments, + evaluation_arguments, + set_color +) + +class Config(object): + """Configurator module that load the defined parameters. + + Configurator module will first load the default parameters from the fixed properties in TalkingFace and then + load parameters from the external input. + + External input supports three kind of forms: config file, command line and parameter dictionaries. + + - config file: It's a file that record the parameters to be modified or added. It should be in ``yaml`` format, + e.g. a config file is 'example.yaml', the content is: + + learning_rate: 0.001 + + train_batch_size: 2048 + + - command line: It should be in the format as '---learning_rate=0.001' + + - parameter dictionaries: It should be a dict, where the key is parameter name and the value is parameter value, + e.g. config_dict = {'learning_rate': 0.001} + + Configuration module allows the above three kind of external input format to be used together, + the priority order is as following: + + command line > parameter dictionaries > config file + + e.g. If we set learning_rate=0.01 in config file, learning_rate=0.02 in command line, + learning_rate=0.03 in parameter dictionaries. + + Finally the learning_rate is equal to 0.02. + """ + + def __init__( + self, model=None, dataset=None, config_file_list=None, config_dict=None + ): + """ + Args: + model (str/AbstractRecommender): the model name or the model class, default is None, if it is None, config + will search the parameter 'model' from the external input as the model name or model class. + dataset (str): the dataset name, default is None, if it is None, config will search the parameter 'dataset' + from the external input as the dataset name. + config_file_list (list of str): the external config file, it allows multiple config files, default is None. + config_dict (dict): the external parameter dictionaries, default is None. 
+ """ + self.compatibility_settings() + self._init_parameters_category() + self.yaml_loader = self._build_yaml_loader() + self.file_config_dict = self._load_config_files(config_file_list) + self.variable_config_dict = self._load_variable_config_dict(config_dict) + self.cmd_config_dict = self._load_cmd_line() + self._merge_external_config_dict() + + self.model, self.model_class, self.dataset = self._get_model_and_dataset( + model, dataset + ) + self._load_internal_config_dict(self.model, self.model_class, self.dataset) + self.final_config_dict = self._get_final_config_dict() + self._set_default_parameters() + self._init_device() + + def _init_parameters_category(self): + self.parameters = dict() + self.parameters["General"] = general_arguments + self.parameters["Training"] = training_arguments + self.parameters["Evaluation"] = evaluation_arguments + + def _build_yaml_loader(self): + loader = yaml.FullLoader + loader.add_implicit_resolver( + "tag:yaml.org,2002:float", + re.compile( + """^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
+ |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$""", + re.X, + ), + list("-+0123456789."), + ) + return loader + + def _convert_config_dict(self, config_dict): + """This function convert the str parameters to their original type.""" + for key in config_dict: + param = config_dict[key] + if not isinstance(param, str): + continue + try: + value = eval(param) + if value is not None and not isinstance( + value, (str, int, float, list, tuple, dict, bool) + ): + value = param + except (NameError, SyntaxError, TypeError): + if isinstance(param, str): + if param.lower() == "true": + value = True + elif param.lower() == "false": + value = False + else: + value = param + else: + value = param + config_dict[key] = value + return config_dict + + def _load_config_files(self, file_list): + file_config_dict = dict() + if file_list: + for file in file_list: + with open(file, "r", encoding="utf-8") as f: + file_config_dict.update( + yaml.load(f.read(), Loader=self.yaml_loader) + ) + return file_config_dict + + def _load_variable_config_dict(self, config_dict): + # HyperTuning may set the parameters such as mlp_hidden_size in NeuMF in the format of ['[]', '[]'] + # then config_dict will receive a str '[]', but indeed it's a list [] + # temporarily use _convert_config_dict to solve this problem + return self._convert_config_dict(config_dict) if config_dict else dict() + + def _load_cmd_line(self): + r"""Read parameters from command line and convert it to str.""" + cmd_config_dict = dict() + unrecognized_args = [] + if "ipykernel_launcher" not in sys.argv[0]: + for arg in sys.argv[1:]: + if not arg.startswith("--") or len(arg[2:].split("=")) != 2: + unrecognized_args.append(arg) + continue + cmd_arg_name, cmd_arg_value = arg[2:].split("=") + if ( + cmd_arg_name in cmd_config_dict + and cmd_arg_value != cmd_config_dict[cmd_arg_name] + ): + raise SyntaxError( + "There are duplicate commend arg '%s' with different value." 
+ % arg + ) + else: + cmd_config_dict[cmd_arg_name] = cmd_arg_value + if len(unrecognized_args) > 0: + logger = getLogger() + logger.warning( + "command line args [{}] will not be used in RecBole".format( + " ".join(unrecognized_args) + ) + ) + cmd_config_dict = self._convert_config_dict(cmd_config_dict) + return cmd_config_dict + + def _merge_external_config_dict(self): + external_config_dict = dict() + external_config_dict.update(self.file_config_dict) + external_config_dict.update(self.variable_config_dict) + external_config_dict.update(self.cmd_config_dict) + self.external_config_dict = external_config_dict + + def _get_model_and_dataset(self, model, dataset): + if model is None: + try: + model = self.external_config_dict["model"] + except KeyError: + raise KeyError( + "model need to be specified in at least one of the these ways: " + "[model variable, config file, config dict, command line] " + ) + if not isinstance(model, str): + final_model_class = model + final_model = model.__name__ + else: + final_model = model + final_model_class = get_model(final_model) + + if dataset is None: + try: + final_dataset = self.external_config_dict["dataset"] + except KeyError: + raise KeyError( + "dataset need to be specified in at least one of the these ways: " + "[dataset variable, config file, config dict, command line] " + ) + else: + final_dataset = dataset + + return final_model, final_model_class, final_dataset + + def _update_internal_config_dict(self, file): + with open(file, "r", encoding="utf-8") as f: + config_dict = yaml.load(f.read(), Loader=self.yaml_loader) + if config_dict is not None: + self.internal_config_dict.update(config_dict) + return config_dict + def _load_internal_config_dict(self, model, model_class, dataset): + current_path = os.path.dirname(os.path.realpath(__file__)) + overall_init_file = os.path.join(current_path, "../properties/overall.yaml") + model_init_file = os.path.join( + current_path, "../properties/model/" + model + ".yaml" + ) + 
dataset_init_file = os.path.join( + current_path, "../properties/dataset/" + dataset + ".yaml" + ) + + + self.internal_config_dict = dict() + for file in [ + overall_init_file, + model_init_file, + dataset_init_file, + ]: + if os.path.isfile(file): + config_dict = self._update_internal_config_dict(file) + # if file == dataset_init_file: + # self.parameters["Dataset"] += [ + # key + # for key in config_dict.keys() + # if key not in self.parameters["Dataset"] + # ] + + def _get_final_config_dict(self): + final_config_dict = dict() + final_config_dict.update(self.internal_config_dict) + final_config_dict.update(self.external_config_dict) + return final_config_dict + + def _set_default_parameters(self): + self.final_config_dict["dataset"] = self.dataset + self.final_config_dict["model"] = self.model + + metrics = self.final_config_dict["metrics"] + if isinstance(metrics, str): + self.final_config_dict["metrics"] = [metrics] + + self.final_config_dict["checkpoint_dir"] = self.final_config_dict["checkpoint_dir"] + self.final_config_dict["checkpoint_sub_dir"] + + self.final_config_dict["temp_dir"] = self.final_config_dict['temp_dir'] + self.final_config_dict['temp_sub_dir'] + + def _init_device(self): + if isinstance(self.final_config_dict["gpu_id"], tuple): + self.final_config_dict["gpu_id"] = ",".join( + map(str, list(self.final_config_dict["gpu_id"])) + ) + else: + self.final_config_dict["gpu_id"] = str(self.final_config_dict["gpu_id"]) + gpu_id = self.final_config_dict["gpu_id"] + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id + + def __setitem__(self, key, value): + if not isinstance(key, str): + raise TypeError("index must be a str.") + self.final_config_dict[key] = value + + def __getattr__(self, item): + if "final_config_dict" not in self.__dict__: + raise AttributeError( + f"'Config' object has no attribute 'final_config_dict'" + ) + if item in self.final_config_dict: + return self.final_config_dict[item] + raise AttributeError(f"'Config' object has no attribute 
'{item}'") + + def __getitem__(self, item): + return self.final_config_dict.get(item) + + def __contains__(self, key): + if not isinstance(key, str): + raise TypeError("index must be a str.") + return key in self.final_config_dict + + def __str__(self): + args_info = "\n" + for category in self.parameters: + args_info += set_color(category + " Hyper Parameters:\n", "pink") + args_info += "\n".join( + [ + ( + set_color("{}", "cyan") + " =" + set_color(" {}", "yellow") + ).format(arg, value) + for arg, value in self.final_config_dict.items() + if arg in self.parameters[category] + ] + ) + args_info += "\n\n" + + args_info += set_color("Other Hyper Parameters: \n", "pink") + args_info += "\n".join( + [ + (set_color("{}", "cyan") + " = " + set_color("{}", "yellow")).format( + arg, value + ) + for arg, value in self.final_config_dict.items() + if arg + not in {_ for args in self.parameters.values() for _ in args}.union( + {"model", "dataset", "config_files"} + ) + ] + ) + args_info += "\n\n" + return args_info + + def __repr__(self): + return self.__str__() + + + def compatibility_settings(self): + import numpy as np + + np.bool = np.bool_ + np.int = np.int_ + np.float = np.float_ + np.complex = np.complex_ + np.object = np.object_ + np.str = np.str_ + np.long = np.int_ + np.unicode = np.unicode_ \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/data/__init__.py b/talkingface-toolkit-main/talkingface/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/talkingface-toolkit-main/talkingface/data/dataprocess/ADNeRFmaster_process.py b/talkingface-toolkit-main/talkingface/data/dataprocess/ADNeRFmaster_process.py new file mode 100644 index 00000000..e69de29b diff --git a/talkingface-toolkit-main/talkingface/data/dataprocess/__init__.py b/talkingface-toolkit-main/talkingface/data/dataprocess/__init__.py new file mode 100644 index 00000000..7c7aef87 --- /dev/null +++ 
b/talkingface-toolkit-main/talkingface/data/dataprocess/__init__.py @@ -0,0 +1 @@ +from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio, Wav2LipPreprocessForInference \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/data/dataprocess/wav2lip_process.py b/talkingface-toolkit-main/talkingface/data/dataprocess/wav2lip_process.py new file mode 100644 index 00000000..e2279a42 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/data/dataprocess/wav2lip_process.py @@ -0,0 +1,253 @@ +import os +import cv2 +import numpy as np +import subprocess +from tqdm import tqdm +from glob import glob +from concurrent.futures import ThreadPoolExecutor, as_completed +from talkingface.utils import face_detection +import librosa +import librosa.filters +from scipy import signal +from scipy.io import wavfile + +class Wav2LipAudio: + """This class is used for audio processing of wav2lip + + 这个类提供了从音频到mel谱的方法 + """ + + def __init__(self, config): + self.config = config + # Conversions + self._mel_basis = None + + def load_wav(self, path, sr): + return librosa.core.load(path, sr=sr)[0] + + def save_wav(self, wav, path, sr): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + #proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + + def save_wavenet_wav(self, wav, path, sr): + librosa.output.write_wav(path, wav, sr=sr) + + def preemphasis(self, wav, k, preemphasize=True): + if preemphasize: + return signal.lfilter([1, -k], [1], wav) + return wav + + def inv_preemphasis(self, wav, k, inv_preemphasize=True): + if inv_preemphasize: + return signal.lfilter([1], [1, -k], wav) + return wav + + def get_hop_size(self): + hop_size = self.config['hop_size'] + if hop_size is None: + assert self.config['frame_shift_ms'] is not None + hop_size = int(self.config['frame_shift_ms'] / 1000 * self.config['sample_rate']) + return hop_size + + def linearspectrogram(self, wav): + D = self._stft(self.preemphasis(wav, self.config['preemphasis'], 
self.config['preemphasize'])) + S = self._amp_to_db(np.abs(D)) - self.config['ref_level_db'] + + if self.config['signal_normalization']: + return self._normalize(S) + return S + + def melspectrogram(self, wav): + D = self._stft(self.preemphasis(wav, self.config['preemphasis'], self.config['preemphasize'])) + S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - self.config['ref_level_db'] + + if self.config['signal_normalization']: + return self._normalize(S) + return S + + def _lws_processor(self): + import lws + return lws.lws(self.config['n_fft'], self.get_hop_size(), fftsize=self.config['win_size'], mode="speech") + + def _stft(self, y): + if self.config['use_lws']: + return self._lws_processor().stft(y).T + else: + return librosa.stft(y=y, n_fft=self.config['n_fft'], hop_length=self.get_hop_size(), win_length=self.config['win_size']) + + ########################################################## + #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 
+    def num_frames(self, length, fsize, fshift):
+        """Compute number of time frames of spectrogram
+        """
+        pad = (fsize - fshift)
+        if length % fshift == 0:
+            M = (length + pad * 2 - fsize) // fshift + 1
+        else:
+            M = (length + pad * 2 - fsize) // fshift + 2
+        return M
+
+
+    def pad_lr(self, x, fsize, fshift):
+        """Compute left and right padding
+        """
+        M = self.num_frames(len(x), fsize, fshift)
+        pad = (fsize - fshift)
+        T = len(x) + 2 * pad
+        r = (M - 1) * fshift + fsize - T
+        return pad, pad + r
+    ##########################################################
+    #Librosa correct padding
+    def librosa_pad_lr(self, x, fsize, fshift):
+        return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+
+    def _linear_to_mel(self, spectogram):
+        if self._mel_basis is None:
+            self._mel_basis = self._build_mel_basis()
+        return np.dot(self._mel_basis, spectogram)
+
+    def _build_mel_basis(self):
+        assert self.config['fmax'] <= self.config['sample_rate'] // 2
+        return librosa.filters.mel(sr=self.config['sample_rate'], n_fft=self.config['n_fft'], n_mels=self.config['num_mels'],
+                                   fmin=self.config['fmin'], fmax=self.config['fmax'])
+
+    def _amp_to_db(self, x):
+        min_level = np.exp(self.config['min_level_db'] / 20 * np.log(10))
+        return 20 * np.log10(np.maximum(min_level, x))
+
+    def _db_to_amp(self, x):
+        return np.power(10.0, (x) * 0.05)
+
+    def _normalize(self, S):
+        if self.config['allow_clipping_in_normalization']:
+            if self.config['symmetric_mels']:
+                return np.clip((2 * self.config['max_abs_value']) * ((S - self.config['min_level_db']) / (-self.config['min_level_db'])) - self.config['max_abs_value'],
+                               -self.config['max_abs_value'], self.config['max_abs_value'])
+            else:
+                return np.clip(self.config['max_abs_value'] * ((S - self.config['min_level_db']) / (-self.config['min_level_db'])), 0, self.config['max_abs_value'])
+
+        assert S.max() <= 0 and S.min() - self.config['min_level_db'] >= 0
+        if self.config['symmetric_mels']:
+            return (2 * self.config['max_abs_value']) * ((S - 
self.config['min_level_db']) / (-self.config['min_level_db'])) - self.config['max_abs_value']
+        else:
+            return self.config['max_abs_value'] * ((S - self.config['min_level_db']) / (-self.config['min_level_db']))
+
+    def _denormalize(self, D):
+        if self.config['allow_clipping_in_normalization']:
+            if self.config['symmetric_mels']:
+                return (((np.clip(D, -self.config['max_abs_value'],
+                                  self.config['max_abs_value']) + self.config['max_abs_value']) * -self.config['min_level_db'] / (2 * self.config['max_abs_value']))
+                        + self.config['min_level_db'])
+            else:
+                return ((np.clip(D, 0, self.config['max_abs_value']) * -self.config['min_level_db'] / self.config['max_abs_value']) + self.config['min_level_db'])
+
+        if self.config['symmetric_mels']:
+            return (((D + self.config['max_abs_value']) * -self.config['min_level_db'] / (2 * self.config['max_abs_value'])) + self.config['min_level_db'])
+        else:
+            return ((D * -self.config['min_level_db'] / self.config['max_abs_value']) + self.config['min_level_db'])
+
+
+
+class Wav2LipPreprocessForInference:
+    """This class is used for preprocessing of wav2lip inference
+
+    face_detect: detect face in the image
+    datagen: generate data for inference
+
+    """
+    def __init__(self, config):
+        self.config = config
+        self.fa = face_detection.FaceAlignment(face_detection.LandmarksType._2D, flip_input=False,
+                                               device=config['device'])
+
+    def get_smoothened_boxes(self, boxes, T):
+        for i in range(len(boxes)):
+            if i + T > len(boxes):
+                window = boxes[len(boxes) - T:]
+            else:
+                window = boxes[i : i + T]
+            boxes[i] = np.mean(window, axis=0)
+        return boxes
+
+    def face_detect(self, images):
+        batch_size = self.config['batch_size']
+        while 1:
+            predictions = []
+            try:
+                for i in tqdm(range(0, len(images), batch_size), desc='Running face detection', leave=False):
+                    predictions.extend(self.fa.get_detections_for_batch(np.array(images[i:i + batch_size])))
+            except RuntimeError:
+                if batch_size == 1:
+                    raise RuntimeError('Image too big to run face 
detection on GPU. Please use the --resize_factor argument') + batch_size //= 2 + print('Recovering from OOM error; New batch size: {}'.format(batch_size)) + continue + break + results = [] + pady1, pady2, padx1, padx2 = self.config['pads'] + for rect, image in zip(predictions, images): + if rect is None: + cv2.imwrite('temp/faulty_frame.jpg', image) # check this frame where the face was not detected. + raise ValueError('Face not detected! Ensure the video contains a face in all the frames.') + + y1 = max(0, rect[1] - pady1) + y2 = min(image.shape[0], rect[3] + pady2) + x1 = max(0, rect[0] - padx1) + x2 = min(image.shape[1], rect[2] + padx2) + + results.append([x1, y1, x2, y2]) + boxes = np.array(results) + if not self.config['nosmooth']: boxes = self.get_smoothened_boxes(boxes, T=5) + results = [[image[y1: y2, x1:x2], (y1, y2, x1, x2)] for image, (x1, y1, x2, y2) in zip(images, boxes)] + + return results + + + def datagen(self, frames, face_det_results, mels): + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if self.config['box'][0] == -1: + if not self.config['static']: + face_det_results = self.face_detect(frames) # BGR2RGB for CNN face detection + else: + face_det_results = self.face_detect([frames[0]]) + else: + print('Using the specified bounding box instead of face detection...') + y1, y2, x1, x2 = self.config['box'] + face_det_results = [[f[y1: y2, x1:x2], (y1, y2, x1, x2)] for f in frames] + + for i, m in enumerate(mels): + idx = 0 if self.config['static'] else i%len(frames) + frame_to_save = frames[idx].copy() + face, coords = face_det_results[idx].copy() + + face = cv2.resize(face, (self.config['img_size'], self.config['img_size'])) + + img_batch.append(face) + mel_batch.append(m) + frame_batch.append(frame_to_save) + coords_batch.append(coords) + + if len(img_batch) >= self.config['wav2lip_batch_size']: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, 
self.config['img_size']//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch + img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] + + if len(img_batch) > 0: + img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) + + img_masked = img_batch.copy() + img_masked[:, self.config['img_size']//2:] = 0 + + img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255. + mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1]) + + yield img_batch, mel_batch, frame_batch, coords_batch \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/data/dataset/AD_NeRF_master_dataset.py b/talkingface-toolkit-main/talkingface/data/dataset/AD_NeRF_master_dataset.py new file mode 100644 index 00000000..155860b4 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/data/dataset/AD_NeRF_master_dataset.py @@ -0,0 +1,282 @@ +import cv2 +import numpy as np +import face_alignment +from skimage import io +import torch +import torch.nn.functional as F +import json +import os +from sklearn.neighbors import NearestNeighbors +from pathlib import Path +import argparse + + +def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, + device=euler_angle.device) + zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, + device=euler_angle.device) + rot_x = torch.cat(( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), 2) + rot_y = torch.cat(( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + 
torch.cat((phi.sin(), zero, phi.cos()), 1), + ), 2) + rot_z = torch.cat(( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1) + ), 2) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + +parser = argparse.ArgumentParser() +parser.add_argument('--id', type=str, + default='obama', help='identity of target person') +parser.add_argument('--step', type=int, + default=0, help='step for running') + +args = parser.parse_args() +id = args.id +vid_file = os.path.join('dataset', 'vids', id + '.mp4') +if not os.path.isfile(vid_file): + print('no video') + exit() + +id_dir = os.path.join('dataset', id) +Path(id_dir).mkdir(parents=True, exist_ok=True) +ori_imgs_dir = os.path.join('dataset', id, 'ori_imgs') +Path(ori_imgs_dir).mkdir(parents=True, exist_ok=True) +parsing_dir = os.path.join(id_dir, 'parsing') +Path(parsing_dir).mkdir(parents=True, exist_ok=True) +head_imgs_dir = os.path.join('dataset', id, 'head_imgs') +Path(head_imgs_dir).mkdir(parents=True, exist_ok=True) +com_imgs_dir = os.path.join('dataset', id, 'com_imgs') +Path(com_imgs_dir).mkdir(parents=True, exist_ok=True) + +running_step = args.step + +# # Step 0: extract wav & deepspeech feature, better run in terminal to parallel with +# below commands since this may take a few minutes +if running_step == 0: + print('--- Step0: extract deepspeech feature ---') + wav_file = os.path.join(id_dir, 'aud.wav') + extract_wav_cmd = 'ffmpeg -i ' + vid_file + ' -f wav -ar 16000 ' + wav_file + os.system(extract_wav_cmd) + extract_ds_cmd = 'python data_util/deepspeech_features/extract_ds_features.py --input=' + id_dir + os.system(extract_ds_cmd) + exit() + +# Step 1: extract images +if running_step == 1: + print('--- Step1: extract images from vids ---') + cap = cv2.VideoCapture(vid_file) + frame_num = 0 + while (True): + _, frame = cap.read() + if frame is None: + break + cv2.imwrite(os.path.join(ori_imgs_dir, str(frame_num) + '.jpg'), frame) + 
frame_num = frame_num + 1 + cap.release() + exit() + +# Step 2: detect lands +if running_step == 2: + print('--- Step 2: detect landmarks ---') + fa = face_alignment.FaceAlignment( + face_alignment.LandmarksType._2D, flip_input=False) + for image_path in os.listdir(ori_imgs_dir): + if image_path.endswith('.jpg'): + input = io.imread(os.path.join(ori_imgs_dir, image_path))[:, :, :3] + preds = fa.get_landmarks(input) + if len(preds) > 0: + lands = preds[0].reshape(-1, 2)[:, :2] + np.savetxt(os.path.join(ori_imgs_dir, image_path[:-3] + 'lms'), lands, '%f') + +max_frame_num = 100000 +valid_img_ids = [] +for i in range(max_frame_num): + if os.path.isfile(os.path.join(ori_imgs_dir, str(i) + '.lms')): + valid_img_ids.append(i) +valid_img_num = len(valid_img_ids) +tmp_img = cv2.imread(os.path.join(ori_imgs_dir, str(valid_img_ids[0]) + '.jpg')) +h, w = tmp_img.shape[0], tmp_img.shape[1] + +# Step 3: face parsing +if running_step == 3: + print('--- Step 3: face parsing ---') + face_parsing_cmd = 'python data_util/face_parsing/test.py --respath=dataset/' + \ + id + '/parsing --imgpath=dataset/' + id + '/ori_imgs' + os.system(face_parsing_cmd) + +# Step 4: extract bc image +if running_step == 4: + print('--- Step 4: extract background image ---') + sel_ids = np.array(valid_img_ids)[np.arange(0, valid_img_num, 20)] + all_xys = np.mgrid[0:h, 0:w].reshape(2, -1).transpose() + distss = [] + for i in sel_ids: + parse_img = cv2.imread(os.path.join(id_dir, 'parsing', str(i) + '.png')) + bg = (parse_img[..., 0] == 255) & ( + parse_img[..., 1] == 255) & (parse_img[..., 2] == 255) + fg_xys = np.stack(np.nonzero(~bg)).transpose(1, 0) + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) + dists, _ = nbrs.kneighbors(all_xys) + distss.append(dists) + distss = np.stack(distss) + print(distss.shape) + max_dist = np.max(distss, 0) + max_id = np.argmax(distss, 0) + bc_pixs = max_dist > 5 + bc_pixs_id = np.nonzero(bc_pixs) + bc_ids = max_id[bc_pixs] + imgs = [] + num_pixs = 
distss.shape[1] + for i in sel_ids: + img = cv2.imread(os.path.join(ori_imgs_dir, str(i) + '.jpg')) + imgs.append(img) + imgs = np.stack(imgs).reshape(-1, num_pixs, 3) + bc_img = np.zeros((h * w, 3), dtype=np.uint8) + bc_img[bc_pixs_id, :] = imgs[bc_ids, bc_pixs_id, :] + bc_img = bc_img.reshape(h, w, 3) + max_dist = max_dist.reshape(h, w) + bc_pixs = max_dist > 5 + bg_xys = np.stack(np.nonzero(~bc_pixs)).transpose() + fg_xys = np.stack(np.nonzero(bc_pixs)).transpose() + nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree').fit(fg_xys) + distances, indices = nbrs.kneighbors(bg_xys) + bg_fg_xys = fg_xys[indices[:, 0]] + print(fg_xys.shape) + print(np.max(bg_fg_xys), np.min(bg_fg_xys)) + bc_img[bg_xys[:, 0], bg_xys[:, 1], + :] = bc_img[bg_fg_xys[:, 0], bg_fg_xys[:, 1], :] + cv2.imwrite(os.path.join(id_dir, 'bc.jpg'), bc_img) + +# Step 5: save training images +if running_step == 5: + print('--- Step 5: save training images ---') + bc_img = cv2.imread(os.path.join(id_dir, 'bc.jpg')) + for i in valid_img_ids: + parsing_img = cv2.imread(os.path.join(parsing_dir, str(i) + '.png')) + head_part = (parsing_img[:, :, 0] == 255) & ( + parsing_img[:, :, 1] == 0) & (parsing_img[:, :, 2] == 0) + bc_part = (parsing_img[:, :, 0] == 255) & ( + parsing_img[:, :, 1] == 255) & (parsing_img[:, :, 2] == 255) + img = cv2.imread(os.path.join(ori_imgs_dir, str(i) + '.jpg')) + img[bc_part] = bc_img[bc_part] + cv2.imwrite(os.path.join(com_imgs_dir, str(i) + '.jpg'), img) + img[~head_part] = bc_img[~head_part] + cv2.imwrite(os.path.join(head_imgs_dir, str(i) + '.jpg'), img) + +# Step 6: estimate head pose +if running_step == 6: + print('--- Estimate Head Pose ---') + est_pose_cmd = 'python data_util/face_tracking/face_tracker.py --idname=' + \ + id + ' --img_h=' + str(h) + ' --img_w=' + str(w) + \ + ' --frame_num=' + str(max_frame_num) + os.system(est_pose_cmd) + exit() + +# Step 7: save transform param & write config file +if running_step == 7: + print('--- Step 7: Save Transform Param 
---') + params_dict = torch.load(os.path.join(id_dir, 'track_params.pt')) + focal_len = params_dict['focal'] + euler_angle = params_dict['euler'] + trans = params_dict['trans'] / 10.0 + valid_num = euler_angle.shape[0] + train_val_split = int(valid_num * 10 / 11) + train_ids = torch.arange(0, train_val_split) + val_ids = torch.arange(train_val_split, valid_num) + rot = euler2rot(euler_angle) + rot_inv = rot.permute(0, 2, 1) + trans_inv = -torch.bmm(rot_inv, trans.unsqueeze(2)) + pose = torch.eye(4, dtype=torch.float32) + save_ids = ['train', 'val'] + train_val_ids = [train_ids, val_ids] + mean_z = -float(torch.mean(trans[:, 2]).item()) + for i in range(2): + transform_dict = dict() + transform_dict['focal_len'] = float(focal_len[0]) + transform_dict['cx'] = float(w / 2.0) + transform_dict['cy'] = float(h / 2.0) + transform_dict['frames'] = [] + ids = train_val_ids[i] + save_id = save_ids[i] + for i in ids: + i = i.item() + frame_dict = dict() + frame_dict['img_id'] = int(valid_img_ids[i]) + frame_dict['aud_id'] = int(valid_img_ids[i]) + pose[:3, :3] = rot_inv[i] + pose[:3, 3] = trans_inv[i, :, 0] + frame_dict['transform_matrix'] = pose.numpy().tolist() + lms = np.loadtxt(os.path.join( + ori_imgs_dir, str(valid_img_ids[i]) + '.lms')) + min_x, max_x = np.min(lms, 0)[0], np.max(lms, 0)[0] + cx = int((min_x + max_x) / 2.0) + cy = int(lms[27, 1]) + h_w = int((max_x - cx) * 1.5) + h_h = int((lms[8, 1] - cy) * 1.15) + rect_x = cx - h_w + rect_y = cy - h_h + if rect_x < 0: + rect_x = 0 + if rect_y < 0: + rect_y = 0 + rect_w = min(w - 1 - rect_x, 2 * h_w) + rect_h = min(h - 1 - rect_y, 2 * h_h) + rect = np.array((rect_x, rect_y, rect_w, rect_h), dtype=np.int32) + frame_dict['face_rect'] = rect.tolist() + transform_dict['frames'].append(frame_dict) + with open(os.path.join(id_dir, 'transforms_' + save_id + '.json'), 'w') as fp: + json.dump(transform_dict, fp, indent=2, separators=(',', ': ')) + + dir_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + 
testskip = int(val_ids.shape[0] / 7) + + HeadNeRF_config_file = os.path.join(id_dir, 'HeadNeRF_config.txt') + with open(HeadNeRF_config_file, 'w') as file: + file.write('expname = ' + id + '_head\n') + file.write('datadir = ' + os.path.join(dir_path, 'dataset', id) + '\n') + file.write('basedir = ' + os.path.join(dir_path, + 'dataset', id, 'logs') + '\n') + file.write('near = ' + str(mean_z - 0.2) + '\n') + file.write('far = ' + str(mean_z + 0.4) + '\n') + file.write('testskip = ' + str(testskip) + '\n') + Path(os.path.join(dir_path, 'dataset', id, 'logs', id + '_head') + ).mkdir(parents=True, exist_ok=True) + + ComNeRF_config_file = os.path.join(id_dir, 'TorsoNeRF_config.txt') + with open(ComNeRF_config_file, 'w') as file: + file.write('expname = ' + id + '_com\n') + file.write('datadir = ' + os.path.join(dir_path, 'dataset', id) + '\n') + file.write('basedir = ' + os.path.join(dir_path, + 'dataset', id, 'logs') + '\n') + file.write('near = ' + str(mean_z - 0.2) + '\n') + file.write('far = ' + str(mean_z + 0.4) + '\n') + file.write('testskip = ' + str(testskip) + '\n') + Path(os.path.join(dir_path, 'dataset', id, 'logs', id + '_com') + ).mkdir(parents=True, exist_ok=True) + + ComNeRFTest_config_file = os.path.join(id_dir, 'TorsoNeRFTest_config.txt') + with open(ComNeRFTest_config_file, 'w') as file: + file.write('expname = ' + id + '_com\n') + file.write('datadir = ' + os.path.join(dir_path, 'dataset', id) + '\n') + file.write('basedir = ' + os.path.join(dir_path, + 'dataset', id, 'logs') + '\n') + file.write('near = ' + str(mean_z - 0.2) + '\n') + file.write('far = ' + str(mean_z + 0.4) + '\n') + file.write('with_test = ' + str(1) + '\n') + file.write('test_pose_file = transforms_val.json' + '\n') + + print(id + ' data processed done!') diff --git a/talkingface-toolkit-main/talkingface/data/dataset/__init__.py b/talkingface-toolkit-main/talkingface/data/dataset/__init__.py new file mode 100644 index 00000000..4d598341 --- /dev/null +++ 
b/talkingface-toolkit-main/talkingface/data/dataset/__init__.py
@@ -0,0 +1,3 @@
+from talkingface.data.dataset.dataset import Dataset
+from talkingface.data.dataset.meta_portrait_base_dataset import *
+# from talkingface.data.dataset.wav2lip_dataset import Wav2LipDataset
diff --git a/talkingface-toolkit-main/talkingface/data/dataset/dataset.py b/talkingface-toolkit-main/talkingface/data/dataset/dataset.py
new file mode 100644
index 00000000..2de27bd7
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/data/dataset/dataset.py
@@ -0,0 +1,25 @@
+import torch
+
+class Dataset(torch.utils.data.Dataset):
+    def __init__(self, config, datasplit):
+
+        """
+        args: datasplit: str, 'train', 'val' or 'test'(这个参数必须要有, 提前将数据集划分为train, val和test三个部分,
+              具体参数形式可以自己定,只要在你的dataset子类中可以获取到数据就可以,
+              对应的配置文件的参数为:train_filelist, val_filelist和test_filelist)
+
+        """
+
+        self.config = config
+        self.split = datasplit
+
+    def __getitem__(self, index):
+
+        """
+        Returns:
+            data: dict, 必须是一个字典格式, 具体数据解析在model文件里解析
+
+        """
+
+
+        raise NotImplementedError
\ No newline at end of file
diff --git a/talkingface-toolkit-main/talkingface/data/dataset/wav2lip_dataset.py b/talkingface-toolkit-main/talkingface/data/dataset/wav2lip_dataset.py
new file mode 100644
index 00000000..e52ae28d
--- /dev/null
+++ b/talkingface-toolkit-main/talkingface/data/dataset/wav2lip_dataset.py
@@ -0,0 +1,158 @@
+from os.path import dirname, join, basename, isfile
+from tqdm import tqdm
+from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio
+import python_speech_features
+
+import torch
+from torch import nn
+from torch import optim
+import torch.backends.cudnn as cudnn
+from torch.utils import data as data_utils
+import numpy as np
+import librosa
+from talkingface.data.dataset.dataset import Dataset
+
+from glob import glob
+
+import os, random, cv2, argparse
+
+class Wav2LipDataset(Dataset):
+    def __init__(self, config, datasplit):
+        super().__init__(config, datasplit)
+        self.all_videos = 
self.get_image_list(self.config['preprocessed_root'], datasplit) + self.audio_processor = Wav2LipAudio(self.config) + + + def get_image_list(self, data_root, split_path): + filelist = [] + + with open(split_path) as f: + for line in f: + line = line.strip() + if ' ' in line: line = line.split()[0] + filelist.append(os.path.join(data_root, line)) + + return filelist + + + def get_frame_id(self, frame): + return int(basename(frame).split('.')[0]) + + def get_window(self, start_frame): + start_id = self.get_frame_id(start_frame) + vidname = dirname(start_frame) + + window_fnames = [] + for frame_id in range(start_id, start_id + self.config['syncnet_T']): + frame = join(vidname, '{}.jpg'.format(frame_id)) + if not isfile(frame): + return None + window_fnames.append(frame) + return window_fnames + + def read_window(self, window_fnames): + if window_fnames is None: return None + window = [] + for fname in window_fnames: + img = cv2.imread(fname) + if img is None: + return None + try: + img = cv2.resize(img, (self.config['img_size'], self.config['img_size'])) + except Exception as e: + return None + + window.append(img) + + return window + + def crop_audio_window(self, spec, start_frame): + if type(start_frame) == int: + start_frame_num = start_frame + else: + start_frame_num = self.get_frame_id(start_frame) # 0-indexing ---> 1-indexing + start_idx = int(80. 
* (start_frame_num / float(self.config['fps']))) + + end_idx = start_idx + self.config['syncnet_mel_step_size'] + + return spec[start_idx : end_idx, :] + + def get_segmented_mels(self, spec, start_frame): + mels = [] + assert self.config['syncnet_T'] == 5 + start_frame_num = self.get_frame_id(start_frame) + 1 # 0-indexing ---> 1-indexing + if start_frame_num - 2 < 0: return None + for i in range(start_frame_num, start_frame_num + self.config['syncnet_T']): + m = self.crop_audio_window(spec, i - 2) + if m.shape[0] != self.config['syncnet_mel_step_size']: + return None + mels.append(m.T) + + mels = np.asarray(mels) + + return mels + + def prepare_window(self, window): + # 3 x T x H x W + x = np.asarray(window) / 255. + x = np.transpose(x, (3, 0, 1, 2)) + + return x + + def __len__(self): + return len(self.all_videos) + + def __getitem__(self, idx): + while 1: + idx = random.randint(0, len(self.all_videos) - 1) + vidname = self.all_videos[idx] + img_names = list(glob(join(vidname, '*.jpg'))) + if len(img_names) <= 3 * self.config['syncnet_T']: + continue + + img_name = random.choice(img_names) + wrong_img_name = random.choice(img_names) + while wrong_img_name == img_name: + wrong_img_name = random.choice(img_names) + + window_fnames = self.get_window(img_name) + wrong_window_fnames = self.get_window(wrong_img_name) + if window_fnames is None or wrong_window_fnames is None: + continue + + window = self.read_window(window_fnames) + if window is None: + continue + + wrong_window = self.read_window(wrong_window_fnames) + if wrong_window is None: + continue + + try: + wavpath = join(vidname, "audio.wav") + wav = self.audio_processor.load_wav(wavpath, self.config['sample_rate']) + orig_mel = self.audio_processor.melspectrogram(wav).T + except Exception as e: + continue + + mel = self.crop_audio_window(orig_mel.copy(), img_name) + + if (mel.shape[0] != self.config['syncnet_mel_step_size']): + continue + + indiv_mels = self.get_segmented_mels(orig_mel.copy(), img_name) + if 
indiv_mels is None: continue + + window = self.prepare_window(window) + y = window.copy() + window[:, :, window.shape[2]//2:] = 0. + + wrong_window = self.prepare_window(wrong_window) + x = np.concatenate([window, wrong_window], axis=0) + + x = torch.FloatTensor(x) + mel = torch.FloatTensor(mel.T).unsqueeze(0) + # breakpoint() + indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1) + y = torch.FloatTensor(y) + return {"input_frames":x, "indiv_mels":indiv_mels, "mels":mel, "gt":y} diff --git a/talkingface-toolkit-main/talkingface/evaluator/__init__.py b/talkingface-toolkit-main/talkingface/evaluator/__init__.py new file mode 100644 index 00000000..81d4f2ee --- /dev/null +++ b/talkingface-toolkit-main/talkingface/evaluator/__init__.py @@ -0,0 +1,5 @@ +from talkingface.evaluator.metric_models import * +from talkingface.evaluator.metrics import * +from talkingface.evaluator.register import * +from talkingface.evaluator.evaluator import * +from talkingface.evaluator.meta_portrait_base_inference import * \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/evaluator/base_metric.py b/talkingface-toolkit-main/talkingface/evaluator/base_metric.py new file mode 100644 index 00000000..81e15721 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/evaluator/base_metric.py @@ -0,0 +1,98 @@ +import torch +from talkingface.utils import EvaluatorType + +class AbstractMetric(object): + """:class:`AbstractMetric` is the base object of all metrics. If you want to + implement a metric, you should inherit this class. + + Args: + config (Config): the config of evaluator. + """ + + smaller = False + + def __init__(self, config): + self.decimal_place = config["metric_decimal_place"] + + def calculate_metric(self, dataobject): + """Get the dictionary of a metric. + + Args: + dataobject: (dict): it contains all the information needed to calculate metrics. 
class SyncMetric(AbstractMetric):
    """Base class for lip-sync metrics.

    Subclass this and implement :meth:`metric_info` to add a new sync metric
    (e.g. LSE-C / LSE-D).
    """

    metric_type = EvaluatorType.SYNC
    metric_need = ["generated_video"]

    def __init__(self, config):
        super().__init__(config)

    def get_videolist(self, dataobject):
        """Return the list of generated-video paths from the data dict.

        Args:
            dataobject (dict): all information needed to calculate metrics.

        Returns:
            list: generated video paths.
        """
        return dataobject["generated_video"]

    def metric_info(self, dataobject):
        """Compute the metric value(s) for one video; override in subclasses."""
        raise NotImplementedError("Method [metric_info] should be implemented.")


class VideoQMetric(AbstractMetric):
    """Base class for video-quality metrics that compare generated vs. real video."""

    metric_type = EvaluatorType.VIDEOQ

    def __init__(self, config):
        super().__init__(config)

    def get_videopair(self, dataobject):
        """Pair each generated video with its corresponding real video.

        Returns:
            list[tuple]: (generated_path, real_path) pairs.
        """
        generated = dataobject["generated_video"]
        real = dataobject["real_video"]
        return [pair for pair in zip(generated, real)]

    def metric_info(self, dataobject):
        """Compute the metric value for one video pair; override in subclasses."""
        raise NotImplementedError("Method [metric_info] should be implemented.")
+# """ +# def __init__(self, config): +# super(SyncMetric, self).__init__(config) + +# def metric_info(self, dataobject): +# """Calculate the value of the metric. + +# Args: +# dataobject(DataStruct): it contains all the information needed to calculate metrics. + +# Returns: +# float: the value of the metric. +# """ +# raise NotImplementedError("Method [metric_info] should be implemented.") + diff --git a/talkingface-toolkit-main/talkingface/evaluator/evaluator.py b/talkingface-toolkit-main/talkingface/evaluator/evaluator.py new file mode 100644 index 00000000..d2592bee --- /dev/null +++ b/talkingface-toolkit-main/talkingface/evaluator/evaluator.py @@ -0,0 +1,22 @@ +from talkingface.evaluator.register import metrics_dict + +class Evaluator(object): + def __init__(self, config): + self.config = config + self.metrics = [metric.lower() for metric in self.config["metrics"]] + self.metric_class = {} + + for metric in self.metrics: + if metric not in metrics_dict: + raise ValueError(f"Metric '{metric}' is not defined.") + self.metric_class[metric] = metrics_dict[metric](self.config) + + def evaluate(self, datadict): + + result_dict = {} + + for metric in self.metrics: + metric_val = self.metric_class[metric].calculate_metric(datadict) + result_dict[metric] = metric_val + + return result_dict diff --git a/talkingface-toolkit-main/talkingface/evaluator/metric_models.py b/talkingface-toolkit-main/talkingface/evaluator/metric_models.py new file mode 100644 index 00000000..3a38ccba --- /dev/null +++ b/talkingface-toolkit-main/talkingface/evaluator/metric_models.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn + +class S(nn.Module): + def __init__(self, num_layers_in_fc_layers = 1024): + super(S, self).__init__() + + self.__nFeatures__ = 24 + self.__nChs__ = 32 + self.__midChs__ = 32 + + self.netcnnaud = nn.Sequential( + nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1)), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(1,1), 
stride=(1,1)), + + nn.Conv2d(64, 192, kernel_size=(3,3), stride=(1,1), padding=(1,1)), + nn.BatchNorm2d(192), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(3,3), stride=(1,2)), + + nn.Conv2d(192, 384, kernel_size=(3,3), padding=(1,1)), + nn.BatchNorm2d(384), + nn.ReLU(inplace=True), + + nn.Conv2d(384, 256, kernel_size=(3,3), padding=(1,1)), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + + nn.Conv2d(256, 256, kernel_size=(3,3), padding=(1,1)), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)), + + nn.Conv2d(256, 512, kernel_size=(5,4), padding=(0,0)), + nn.BatchNorm2d(512), + nn.ReLU(), + ) + + self.netfcaud = nn.Sequential( + nn.Linear(512, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, num_layers_in_fc_layers), + ) + + self.netfclip = nn.Sequential( + nn.Linear(512, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, num_layers_in_fc_layers), + ) + + self.netcnnlip = nn.Sequential( + nn.Conv3d(3, 96, kernel_size=(5,7,7), stride=(1,2,2), padding=0), + nn.BatchNorm3d(96), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)), + + nn.Conv3d(96, 256, kernel_size=(1,5,5), stride=(1,2,2), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)), + + nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + + nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + + nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)), + + nn.Conv3d(256, 512, kernel_size=(1,6,6), padding=0), + nn.BatchNorm3d(512), + nn.ReLU(inplace=True), + ) + + def forward_aud(self, x): + mid = self.netcnnaud(x); # N x ch x 24 x M + mid = mid.view((mid.size()[0], -1)); # N x (ch x 24) + out = self.netfcaud(mid) + + return 
out + + def forward_lip(self, x): + + mid = self.netcnnlip(x); + mid = mid.view((mid.size()[0], -1)); # N x (ch x 24) + out = self.netfclip(mid) + + return out + + def forward_lipfeat(self, x): + + mid = self.netcnnlip(x) + out = mid.view((mid.size()[0], -1)); # N x (ch x 24) + + return out + \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/evaluator/metrics.py b/talkingface-toolkit-main/talkingface/evaluator/metrics.py new file mode 100644 index 00000000..eb2692c7 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/evaluator/metrics.py @@ -0,0 +1,243 @@ +import torch +import numpy +import time, pdb, argparse, subprocess, os, math, glob +import cv2 +import python_speech_features +from tqdm import tqdm +from scipy import signal +from scipy.io import wavfile +from talkingface.evaluator.metric_models import * +from shutil import rmtree +from skimage.metrics import structural_similarity as ssim +from talkingface.evaluator.base_metric import AbstractMetric, SyncMetric, VideoQMetric +from talkingface.utils.logger import set_color + +class LSE(SyncMetric): + + ''' + ''' + def __init__(self, config, num_layers_in_fc_layers = 1024): + super(LSE, self).__init__(config) + self.config = config + self.syncnet = S(num_layers_in_fc_layers = num_layers_in_fc_layers) + + def metric_info(self, videofile): + self.loadParameters(self.config['lse_checkpoint_path']) + self.syncnet.to(self.config["device"]) + self.syncnet.eval() + + if os.path.exists(os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'])): + rmtree(os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'])) + + os.makedirs(os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'])) + + command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'],'%06d.jpg'))) + output = subprocess.call(command, shell=True, stdout=None) + + command = ("ffmpeg -loglevel 
error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'],'audio.wav'))) + output = subprocess.call(command, shell=True, stdout=None) + + images = [] + + flist = glob.glob(os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'],'*.jpg')) + flist.sort() + + for fname in flist: + img_input = cv2.imread(fname) + img_input = cv2.resize(img_input, (224, 224)) + images.append(img_input) + + im = numpy.stack(images, axis=3) + im = numpy.expand_dims(im, axis=0) + im = numpy.transpose(im, (0, 3, 4, 1, 2)) + + imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) + + sample_rate, audio = wavfile.read(os.path.join(self.config['temp_dir'], self.config['lse_reference_dir'],'audio.wav')) + mfcc = zip(*python_speech_features.mfcc(audio,sample_rate)) + mfcc = numpy.stack([numpy.array(i) for i in mfcc]) + + cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0) + cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float()) + + # ========== ========== + # Check audio and video input length + # ========== ========== + + #if (float(len(audio))/16000) != (float(len(images))/25) : + # print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25)) + + min_length = min(len(images),math.floor(len(audio)/640)) + + # ========== ========== + # Generate video and audio feats + # ========== ========== + + lastframe = min_length-5 + im_feat = [] + cc_feat = [] + + tS = time.time() + for i in range(0,lastframe,self.config['evaluate_batch_size']): + + im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe, i+self.config['evaluate_batch_size'])) ] + im_in = torch.cat(im_batch,0) + im_out = self.syncnet.forward_lip(im_in.cuda()) + im_feat.append(im_out.data.cpu()) + + cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe, 
i+self.config['evaluate_batch_size'])) ] + cc_in = torch.cat(cc_batch,0) + cc_out = self.syncnet.forward_aud(cc_in.cuda()) + cc_feat.append(cc_out.data.cpu()) + + im_feat = torch.cat(im_feat,0) + cc_feat = torch.cat(cc_feat,0) + + # ========== ========== + # Compute offset + # ========== ========== + + #print('Compute time %.3f sec.' % (time.time()-tS)) + + dists = self.calc_pdist(im_feat,cc_feat,vshift=self.config['vshift']) + mdist = torch.mean(torch.stack(dists,1),1) + + minval, minidx = torch.min(mdist,0) + + offset = self.config['vshift']-minidx + conf = torch.median(mdist) - minval + + fdist = numpy.stack([dist[minidx].numpy() for dist in dists]) + # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15) + fconf = torch.median(mdist).numpy() - fdist + fconfm = signal.medfilt(fconf,kernel_size=9) + + numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format}) + + return conf.numpy(), minval.numpy() + + + + def calc_pdist(self, feat1, feat2, vshift=10): + + win_size = vshift*2+1 + + feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift)) + + dists = [] + + for i in range(0,len(feat1)): + + dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:])) + + return dists + + def calculate_metric(self, dataobject): + video_list = self.get_videolist(dataobject) + + LSE_dict = {} + + iter_data = ( + tqdm( + video_list, + total=len(video_list), + desc=set_color("calculate for lip-audio sync", "yellow") + ) + if self.config['show_progress'] + else video_list + ) + + for video in iter_data: + LSE_C, LSE_D = self.metric_info(video) + if "LSE_C" not in LSE_dict: + LSE_dict["LSE_C"] = [LSE_C] + else: + LSE_dict["LSE_C"].append(LSE_C) + if "LSE_D" not in LSE_dict: + LSE_dict["LSE_D"] = [LSE_D] + else: + LSE_dict["LSE_D"].append(LSE_D) + + return {"LSE-C: {}".format(sum(LSE_dict["LSE_C"])/len(LSE_dict["LSE_C"])), "LSE-D: {}".format(sum(LSE_dict["LSE_D"])/len(LSE_dict["LSE_D"]))} + + def loadParameters(self, 
class SSIM(VideoQMetric):
    """Mean structural-similarity (SSIM) between generated and real videos."""

    metric_need = ["generated_video", "real_video"]

    def __init__(self, config):
        super(SSIM, self).__init__(config)
        self.config = config

    def _load_gray_frames(self, videofile):
        # Decode every frame, resize to 224x224 and convert to grayscale,
        # matching the resolution SSIM is computed at.
        frames = []
        capture = cv2.VideoCapture(videofile)
        while capture.isOpened():
            ok, frame = capture.read()
            if not ok:
                break
            frame = cv2.resize(frame, (224, 224))
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY))
        capture.release()
        return frames

    def metric_info(self, g_videofile, r_videofile):
        """Average per-frame SSIM between one generated and one real video.

        Videos of different lengths are truncated to the shorter one.
        """
        g_frames = self._load_gray_frames(g_videofile)
        r_frames = self._load_gray_frames(r_videofile)

        min_frames = min(len(g_frames), len(r_frames))

        # Compute SSIM frame-by-frame on the aligned prefix.
        ssim_scores = [
            ssim(frame1, frame2, full=True)[0]
            for frame1, frame2 in zip(g_frames[:min_frames], r_frames[:min_frames])
        ]

        return numpy.mean(ssim_scores)

    def calculate_metric(self, dataobject):
        """Average SSIM over every (generated, real) video pair."""
        pair_list = self.get_videopair(dataobject)

        iter_data = (
            tqdm(
                pair_list,
                total=len(pair_list),
                desc=set_color("calculate for video quality ssim", "yellow")
            )
            if self.config['show_progress']
            else pair_list
        )

        ssim_score_total = [self.metric_info(g_video, r_video)
                            for g_video, r_video in iter_data]

        return sum(ssim_score_total) / len(ssim_score_total)
import inspect
import sys

def cluster_info(module_name):
    """Collect registration info for every metric class defined in *module_name*.

    For each class declared directly in the module (matched by ``__module__``),
    gather:

    - ``metric_need``: the data keys the metric requires.
    - ``metric_type``: the grouping type of the metric (``EvaluatorType``).
    - ``smaller``: whether a smaller value means better performance.

    Args:
        module_name (str): dotted name of the metrics module; it must already
            be imported (looked up via ``sys.modules``).

    Returns:
        tuple: (smaller_metrics, metric_need_map, metric_type_map, name->class map),
        all keyed by the lower-cased class name.

    Raises:
        AttributeError: if a metric class lacks ``metric_need`` or ``metric_type``.
    """
    members = inspect.getmembers(
        sys.modules[module_name],
        lambda obj: inspect.isclass(obj) and obj.__module__ == module_name,
    )

    smaller_m = []
    m_dict = {}
    m_info = {}
    m_types = {}

    for cls_name, metric_cls in members:
        key = cls_name.lower()
        m_dict[key] = metric_cls
        if not hasattr(metric_cls, "metric_need"):
            raise AttributeError(f"Metric '{key}' has no attribute [metric_need].")
        m_info[key] = metric_cls.metric_need
        if not hasattr(metric_cls, "metric_type"):
            raise AttributeError(f"Metric '{key}' has no attribute [metric_type].")
        m_types[key] = metric_cls.metric_type
        if metric_cls.smaller is True:
            smaller_m.append(key)

    return smaller_m, m_info, m_types, m_dict


metric_module_name = "talkingface.evaluator.metrics"
smaller_metrics, metric_information, metric_types, metrics_dict = cluster_info(
    metric_module_name
)
class AbstractSpeech(nn.Module):
    """Abstract base class for speech-generation models.

    Subclasses must implement :meth:`calculate_loss`, :meth:`predict` and
    :meth:`generate_batch`.
    """

    def __init__(self):
        self.logger = getLogger()
        super(AbstractSpeech, self).__init__()

    def calculate_loss(self, interaction):
        r"""Calculate the training loss for a batch of data.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            dict: {"loss": loss, "xxx": xxx}. The "loss" key is mandatory and
            holds the weighted total loss (the total may be composed of several
            parts); other keys hold the individual loss components.
        """
        raise NotImplementedError

    def predict(self, interaction):
        r"""Run inference on a batch of data.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            video/image numpy/tensor
        """
        raise NotImplementedError

    def generate_batch(self):
        # BUG FIX: the original signature was `generate_batch()` without
        # `self`, so any `model.generate_batch()` call raised TypeError
        # instead of the intended NotImplementedError.
        """Generate outputs in bulk from the test_filelist split.

        Returns:
            dict: {"generated_audio": [...], "real_audio": [...]}. Both keys
            are mandatory, both values are lists of equal length: each
            generated audio is paired with its (approximately) corresponding
            real audio.
        """
        raise NotImplementedError

    def other_parameter(self):
        """Return extra (non-module) state named by `other_parameter_name`."""
        if hasattr(self, "other_parameter_name"):
            return {key: getattr(self, key) for key in self.other_parameter_name}
        return dict()

    def load_other_parameter(self, para):
        """Restore extra state previously produced by :meth:`other_parameter`."""
        if para is None:
            return
        for key, value in para.items():
            setattr(self, key, value)

    def __str__(self):
        """Model prints with number of trainable parameters."""
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return (
            super().__str__()
            + set_color("\nTrainable parameters", "blue")
            + f": {params}"
        )


class AbstractTalkingFace(nn.Module):
    """Abstract base class for talking-face models.

    Subclasses must implement :meth:`calculate_loss`, :meth:`predict` and
    :meth:`generate_batch`.
    """

    def __init__(self):
        self.logger = getLogger()
        super(AbstractTalkingFace, self).__init__()

    def calculate_loss(self, interaction):
        r"""Calculate the training loss for a batch of data.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            dict: {"loss": loss, "xxx": xxx}. The "loss" key is mandatory and
            holds the weighted total loss (the total may be composed of several
            parts); other keys hold the individual loss components.
        """
        raise NotImplementedError

    def predict(self, interaction):
        r"""Run inference on a batch of data.

        Args:
            interaction (Interaction): Interaction class of the batch.

        Returns:
            video/image numpy/tensor
        """
        raise NotImplementedError

    def generate_batch(self):
        # BUG FIX: same missing-`self` defect as AbstractSpeech.generate_batch.
        """Generate outputs in bulk from the test_filelist split.

        Returns:
            dict: {"generated_video": [...], "real_video": [...]}. Both keys
            are mandatory, both values are lists of equal length: each
            generated video is paired with its (approximately) corresponding
            real video.
        """
        raise NotImplementedError

    def other_parameter(self):
        """Return extra (non-module) state named by `other_parameter_name`."""
        if hasattr(self, "other_parameter_name"):
            return {key: getattr(self, key) for key in self.other_parameter_name}
        return dict()

    def load_other_parameter(self, para):
        """Restore extra state previously produced by :meth:`other_parameter`."""
        if para is None:
            return
        for key, value in para.items():
            setattr(self, key, value)

    def __str__(self):
        """Model prints with number of trainable parameters."""
        model_parameters = filter(lambda p: p.requires_grad, self.parameters())
        params = sum([np.prod(p.size()) for p in model_parameters])
        return (
            super().__str__()
            + set_color("\nTrainable parameters", "blue")
            + f": {params}"
        )
class Conv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU, with an optional residual connection.

    Attribute names (`conv_block`, `act`) are kept so existing checkpoints
    load unchanged.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        # Residual add before the activation (requires cin == cout and
        # a shape-preserving stride/padding).
        return self.act(out + x if self.residual else out)


class nonorm_Conv2d(nn.Module):
    """Conv2d -> LeakyReLU without normalization.

    NOTE: the `residual` parameter is accepted for signature parity with
    Conv2d but is intentionally unused.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
        )
        self.act = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, x):
        return self.act(self.conv_block(x))


class Conv2dTranspose(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> ReLU upsampling block."""

    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.conv_block(x))
def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    # NOTE(review): this lives in the ADNeRFmaster package but imports from
    # the LiveSpeechPortraits package path — looks like a copy-paste; confirm
    # the intended module prefix before relying on it.
    model_filename = "talkingface.model.audio_driven_talkingface.LiveSpeechPortraits." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    # e.g. model_name "live_speech" -> class name "livespeechmodel" (case-insensitive).
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        # Hard exit: without a matching model class nothing downstream can run.
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit()

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function warps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
# >>> from models import create_model
# >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model_name)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance


def save_models(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, modelG, modelD, end_of_epoch=False):
    """Checkpoint the generator/discriminator pair and record the resume point.

    Mid-epoch saves (every `save_latest_freq` steps) only refresh 'latest';
    end-of-epoch saves (every `save_epoch_freq` epochs) also write an
    epoch-numbered snapshot and advance the iteration file to (epoch+1, 0).
    """
    if not end_of_epoch:
        if total_steps % opt.save_latest_freq == 0:
            visualizer.vis_print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            modelG.module.save('latest')
            modelD.module.save('latest')
            # Persist (epoch, iter) so an interrupted run can resume here.
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
    else:
        if epoch % opt.save_epoch_freq == 0:
            visualizer.vis_print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
            modelG.module.save('latest')
            modelD.module.save('latest')
            modelG.module.save(epoch)
            modelD.module.save(epoch)
            np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
class myModel(nn.Module):
    """Wraps a model in nn.DataParallel and pads batches so they split
    evenly across GPUs, stripping the padding from the outputs again.
    """

    def __init__(self, opt, model):
        super(myModel, self).__init__()
        self.opt = opt
        self.module = model
        self.model = nn.DataParallel(model, device_ids=opt.gpu_ids)
        # batch size for each GPU
        self.bs_per_gpu = int(np.ceil(float(opt.batch_size) / len(opt.gpu_ids)))
        # How many dummy rows are needed so every GPU gets bs_per_gpu items.
        self.pad_bs = self.bs_per_gpu * len(opt.gpu_ids) - opt.batch_size

    def forward(self, *inputs, **kwargs):
        padded = self.add_dummy_to_tensor(inputs, self.pad_bs)
        outputs = self.model(*padded, **kwargs, dummy_bs=self.pad_bs)
        if self.pad_bs == self.bs_per_gpu:
            # gpu 0 does 0 batch but still returns 1 batch
            return self.remove_dummy_from_tensor(outputs, 1)
        return outputs

    def add_dummy_to_tensor(self, tensors, add_size=0):
        """Prepend `add_size` zero rows to every tensor (recursing into
        lists/tuples); no-op when add_size is 0 or the input is None."""
        if add_size == 0 or tensors is None:
            return tensors
        if type(tensors) in (list, tuple):
            return [self.add_dummy_to_tensor(item, add_size) for item in tensors]
        if isinstance(tensors, torch.Tensor):
            dummy = torch.zeros_like(tensors)[:add_size]
            tensors = torch.cat([dummy, tensors])
        return tensors

    def remove_dummy_from_tensor(self, tensors, remove_size=0):
        """Drop the first `remove_size` rows added by add_dummy_to_tensor."""
        if remove_size == 0 or tensors is None:
            return tensors
        if type(tensors) in (list, tuple):
            return [self.remove_dummy_from_tensor(item, remove_size) for item in tensors]
        if isinstance(tensors, torch.Tensor):
            tensors = tensors[remove_size:]
        return tensors
def conv_audios_to_deepspeech(audios,
                              out_files,
                              num_frames_info,
                              deepspeech_pb_path,
                              audio_window_size=1,
                              audio_window_stride=1):
    """
    Convert list of audio files into files with DeepSpeech features.

    Parameters
    ----------
    audios : list of str or list of None
        Paths to input audio files.
    out_files : list of str
        Paths to output files with DeepSpeech features (saved via np.save).
    num_frames_info : list of int
        List of numbers of frames (one entry per audio; may contain None).
    deepspeech_pb_path : str
        Path to DeepSpeech 0.1.0 frozen model.
    audio_window_size : int, default 1
        Audio window size passed to the per-audio conversion.
    audio_window_stride : int, default 1
        Audio window stride passed to the per-audio conversion.
    """
    # deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm"
    graph, logits_ph, input_node_ph, input_lengths_ph = prepare_deepspeech_net(
        deepspeech_pb_path)

    # One TF1-style session shared by all audios.
    with tf.compat.v1.Session(graph=graph) as sess:
        for audio_file_path, out_file_path, num_frames in zip(audios, out_files, num_frames_info):
            print(audio_file_path)
            print(out_file_path)
            audio_sample_rate, audio = wavfile.read(audio_file_path)
            if audio.ndim != 1:
                warnings.warn(
                    "Audio has multiple channels, the first channel is used")
                audio = audio[:, 0]
            ds_features = pure_conv_audio_to_deepspeech(
                audio=audio,
                audio_sample_rate=audio_sample_rate,
                audio_window_size=audio_window_size,
                audio_window_stride=audio_window_stride,
                num_frames=num_frames,
                # net_fn feeds one feature matrix through the frozen graph.
                net_fn=lambda x: sess.run(
                    logits_ph,
                    feed_dict={
                        input_node_ph: x[np.newaxis, ...],
                        input_lengths_ph: [x.shape[0]]}))

            # Re-window the per-frame 29-dim logits into fixed 16-frame
            # context windows with stride 2, zero-padded at both ends.
            # NOTE(review): win_size=16 is hard-coded here and independent of
            # the audio_window_size parameter above — confirm this is intended.
            net_output = ds_features.reshape(-1, 29)
            win_size = 16
            zero_pad = np.zeros((int(win_size / 2), net_output.shape[1]))
            net_output = np.concatenate(
                (zero_pad, net_output, zero_pad), axis=0)
            windows = []
            for window_index in range(0, net_output.shape[0] - win_size, 2):
                windows.append(
                    net_output[window_index:window_index + win_size])
            print(np.array(windows).shape)
            np.save(out_file_path, np.array(windows))
+ """ + # Load graph and place_holders: + with tf.io.gfile.GFile(deepspeech_pb_path, "rb") as f: + graph_def = tf.compat.v1.GraphDef() + graph_def.ParseFromString(f.read()) + + graph = tf.compat.v1.get_default_graph() + tf.import_graph_def(graph_def, name="deepspeech") + logits_ph = graph.get_tensor_by_name("deepspeech/logits:0") + input_node_ph = graph.get_tensor_by_name("deepspeech/input_node:0") + input_lengths_ph = graph.get_tensor_by_name("deepspeech/input_lengths:0") + + return graph, logits_ph, input_node_ph, input_lengths_ph + + +def pure_conv_audio_to_deepspeech(audio, + audio_sample_rate, + audio_window_size, + audio_window_stride, + num_frames, + net_fn): + """ + Core routine for converting audion into DeepSpeech features. + + Parameters + ---------- + audio : np.array + Audio data. + audio_sample_rate : int + Audio sample rate. + audio_window_size : int + Audio window size. + audio_window_stride : int + Audio window stride. + num_frames : int or None + Numbers of frames. + net_fn : func + Function for DeepSpeech model call. + + Returns + ------- + np.array + DeepSpeech features. 
+ """ + target_sample_rate = 16000 + if audio_sample_rate != target_sample_rate: + resampled_audio = resampy.resample( + x=audio.astype(np.float), + sr_orig=audio_sample_rate, + sr_new=target_sample_rate) + else: + resampled_audio = audio.astype(np.float) + input_vector = conv_audio_to_deepspeech_input_vector( + audio=resampled_audio.astype(np.int16), + sample_rate=target_sample_rate, + num_cepstrum=26, + num_context=9) + + network_output = net_fn(input_vector) + # print(network_output.shape) + + deepspeech_fps = 50 + video_fps = 50 # Change this option if video fps is different + audio_len_s = float(audio.shape[0]) / audio_sample_rate + if num_frames is None: + num_frames = int(round(audio_len_s * video_fps)) + else: + video_fps = num_frames / audio_len_s + network_output = interpolate_features( + features=network_output[:, 0], + input_rate=deepspeech_fps, + output_rate=video_fps, + output_len=num_frames) + + # Make windows: + zero_pad = np.zeros((int(audio_window_size / 2), network_output.shape[1])) + network_output = np.concatenate( + (zero_pad, network_output, zero_pad), axis=0) + windows = [] + for window_index in range(0, network_output.shape[0] - audio_window_size, audio_window_stride): + windows.append( + network_output[window_index:window_index + audio_window_size]) + + return np.array(windows) + + +def conv_audio_to_deepspeech_input_vector(audio, + sample_rate, + num_cepstrum, + num_context): + """ + Convert audio raw data into DeepSpeech input vector. + + Parameters + ---------- + audio : np.array + Audio data. + audio_sample_rate : int + Audio sample rate. + num_cepstrum : int + Number of cepstrum. + num_context : int + Number of context. + + Returns + ------- + np.array + DeepSpeech input vector. 
+ """ + # Get mfcc coefficients: + features = mfcc( + signal=audio, + samplerate=sample_rate, + numcep=num_cepstrum) + + # We only keep every second feature (BiRNN stride = 2): + features = features[::2] + + # One stride per time step in the input: + num_strides = len(features) + + # Add empty initial and final contexts: + empty_context = np.zeros((num_context, num_cepstrum), dtype=features.dtype) + features = np.concatenate((empty_context, features, empty_context)) + + # Create a view into the array with overlapping strides of size + # numcontext (past) + 1 (present) + numcontext (future): + window_size = 2 * num_context + 1 + train_inputs = np.lib.stride_tricks.as_strided( + features, + shape=(num_strides, window_size, num_cepstrum), + strides=(features.strides[0], + features.strides[0], features.strides[1]), + writeable=False) + + # Flatten the second and third dimensions: + train_inputs = np.reshape(train_inputs, [num_strides, -1]) + + train_inputs = np.copy(train_inputs) + train_inputs = (train_inputs - np.mean(train_inputs)) / \ + np.std(train_inputs) + + return train_inputs + + +def interpolate_features(features, + input_rate, + output_rate, + output_len): + """ + Interpolate DeepSpeech features. + + Parameters + ---------- + features : np.array + DeepSpeech features. + input_rate : int + input rate (FPS). + output_rate : int + Output rate (FPS). + output_len : int + Output data length. + + Returns + ------- + np.array + Interpolated data. 
+ """ + input_len = features.shape[0] + num_features = features.shape[1] + input_timestamps = np.arange(input_len) / float(input_rate) + output_timestamps = np.arange(output_len) / float(output_rate) + output_features = np.zeros((output_len, num_features)) + for feature_idx in range(num_features): + output_features[:, feature_idx] = np.interp( + x=output_timestamps, + xp=input_timestamps, + fp=features[:, feature_idx]) + return output_features diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_store.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_store.py new file mode 100644 index 00000000..5595a4d5 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/deepspeech_store.py @@ -0,0 +1,172 @@ +""" + Routines for loading DeepSpeech model. +""" + +__all__ = ['get_deepspeech_model_file'] + +import os +import zipfile +import logging +import hashlib + + +deepspeech_features_repo_url = 'https://github.com/osmr/deepspeech_features' + + +def get_deepspeech_model_file(local_model_store_dir_path=os.path.join("~", ".tensorflow", "models")): + """ + Return location for the pretrained on local file system. This function will download from online model zoo when + model cannot be found or has mismatch. The root directory will be created if it doesn't exist. + + Parameters + ---------- + local_model_store_dir_path : str, default $TENSORFLOW_HOME/models + Location for keeping the model parameters. + + Returns + ------- + file_path + Path to the requested pretrained model file. 
+ """ + sha1_hash = "b90017e816572ddce84f5843f1fa21e6a377975e" + file_name = "deepspeech-0_1_0-b90017e8.pb" + local_model_store_dir_path = os.path.expanduser(local_model_store_dir_path) + file_path = os.path.join(local_model_store_dir_path, file_name) + if os.path.exists(file_path): + if _check_sha1(file_path, sha1_hash): + return file_path + else: + logging.warning("Mismatch in the content of model file detected. Downloading again.") + else: + logging.info("Model file not found. Downloading to {}.".format(file_path)) + + if not os.path.exists(local_model_store_dir_path): + os.makedirs(local_model_store_dir_path) + + zip_file_path = file_path + ".zip" + _download( + url="{repo_url}/releases/download/{repo_release_tag}/{file_name}.zip".format( + repo_url=deepspeech_features_repo_url, + repo_release_tag="v0.0.1", + file_name=file_name), + path=zip_file_path, + overwrite=True) + with zipfile.ZipFile(zip_file_path) as zf: + zf.extractall(local_model_store_dir_path) + os.remove(zip_file_path) + + if _check_sha1(file_path, sha1_hash): + return file_path + else: + raise ValueError("Downloaded file has different hash. Please try again.") + + +def _download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True): + """ + Download an given URL + + Parameters + ---------- + url : str + URL to download + path : str, optional + Destination path to store downloaded file. By default stores to the + current directory with same name as in url. + overwrite : bool, optional + Whether to overwrite destination file if already exists. + sha1_hash : str, optional + Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified + but doesn't match. + retries : integer, default 5 + The number of times to attempt the download in case of failure or non 200 return codes + verify_ssl : bool, default True + Verify SSL certificates. + + Returns + ------- + str + The file path of the downloaded file. 
+ """ + import warnings + try: + import requests + except ImportError: + class requests_failed_to_import(object): + pass + requests = requests_failed_to_import + + if path is None: + fname = url.split("/")[-1] + # Empty filenames are invalid + assert fname, "Can't construct file-name from this URL. Please set the `path` option manually." + else: + path = os.path.expanduser(path) + if os.path.isdir(path): + fname = os.path.join(path, url.split("/")[-1]) + else: + fname = path + assert retries >= 0, "Number of retries should be at least 0" + + if not verify_ssl: + warnings.warn( + "Unverified HTTPS request is being made (verify_ssl=False). " + "Adding certificate verification is strongly advised.") + + if overwrite or not os.path.exists(fname) or (sha1_hash and not _check_sha1(fname, sha1_hash)): + dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) + if not os.path.exists(dirname): + os.makedirs(dirname) + while retries + 1 > 0: + # Disable pyling too broad Exception + # pylint: disable=W0703 + try: + print("Downloading {} from {}...".format(fname, url)) + r = requests.get(url, stream=True, verify=verify_ssl) + if r.status_code != 200: + raise RuntimeError("Failed downloading url {}".format(url)) + with open(fname, "wb") as f: + for chunk in r.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if sha1_hash and not _check_sha1(fname, sha1_hash): + raise UserWarning("File {} is downloaded but the content hash does not match." + " The repo may be outdated or download may be incomplete. 
" + "If the `repo_url` is overridden, consider switching to " + "the default repo.".format(fname)) + break + except Exception as e: + retries -= 1 + if retries <= 0: + raise e + else: + print("download failed, retrying, {} attempt{} left" + .format(retries, "s" if retries > 1 else "")) + + return fname + + +def _check_sha1(filename, sha1_hash): + """ + Check whether the sha1 hash of the file content matches the expected hash. + + Parameters + ---------- + filename : str + Path to the file. + sha1_hash : str + Expected sha1 hash in hexadecimal digits. + + Returns + ------- + bool + Whether the file content matches the expected hash. + """ + sha1 = hashlib.sha1() + with open(filename, "rb") as f: + while True: + data = f.read(1048576) + if not data: + break + sha1.update(data) + + return sha1.hexdigest() == sha1_hash diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_ds_features.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_ds_features.py new file mode 100644 index 00000000..34c1fe14 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_ds_features.py @@ -0,0 +1,129 @@ +""" + Script for extracting DeepSpeech features from audio file. +""" + +import os +import argparse +import numpy as np +import pandas as pd +from deepspeech_store import get_deepspeech_model_file +from deepspeech_features import conv_audios_to_deepspeech + + +def parse_args(): + """ + Create python script parameters. + Returns + ------- + ArgumentParser + Resulted args. 
+ """ + parser = argparse.ArgumentParser( + description="Extract DeepSpeech features from audio file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--input", + type=str, + required=True, + help="path to input audio file or directory") + parser.add_argument( + "--output", + type=str, + help="path to output file with DeepSpeech features") + parser.add_argument( + "--deepspeech", + type=str, + help="path to DeepSpeech 0.1.0 frozen model") + parser.add_argument( + "--metainfo", + type=str, + help="path to file with meta-information") + + args = parser.parse_args() + return args + + +def extract_features(in_audios, + out_files, + deepspeech_pb_path, + metainfo_file_path=None): + """ + Real extract audio from video file. + Parameters + ---------- + in_audios : list of str + Paths to input audio files. + out_files : list of str + Paths to output files with DeepSpeech features. + deepspeech_pb_path : str + Path to DeepSpeech 0.1.0 frozen model. + metainfo_file_path : str, default None + Path to file with meta-information. + """ + #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" + if metainfo_file_path is None: + num_frames_info = [None] * len(in_audios) + else: + train_df = pd.read_csv( + metainfo_file_path, + sep="\t", + index_col=False, + dtype={"Id": np.int, "File": np.unicode, "Count": np.int}) + num_frames_info = train_df["Count"].values + assert (len(num_frames_info) == len(in_audios)) + + for i, in_audio in enumerate(in_audios): + if not out_files[i]: + file_stem, _ = os.path.splitext(in_audio) + out_files[i] = file_stem + ".npy" + #print(out_files[i]) + conv_audios_to_deepspeech( + audios=in_audios, + out_files=out_files, + num_frames_info=num_frames_info, + deepspeech_pb_path=deepspeech_pb_path) + + +def main(): + """ + Main body of script. 
+ """ + args = parse_args() + in_audio = os.path.expanduser(args.input) + if not os.path.exists(in_audio): + raise Exception("Input file/directory doesn't exist: {}".format(in_audio)) + deepspeech_pb_path = args.deepspeech + #deepspeech_pb_path="/disk4/keyu/DeepSpeech/deepspeech-0.9.2-models.pbmm" + if deepspeech_pb_path is None: + deepspeech_pb_path = "" + if deepspeech_pb_path: + deepspeech_pb_path = os.path.expanduser(args.deepspeech) + if not os.path.exists(deepspeech_pb_path): + deepspeech_pb_path = get_deepspeech_model_file() + if os.path.isfile(in_audio): + extract_features( + in_audios=[in_audio], + out_files=[args.output], + deepspeech_pb_path=deepspeech_pb_path, + metainfo_file_path=args.metainfo) + else: + audio_file_paths = [] + for file_name in os.listdir(in_audio): + if not os.path.isfile(os.path.join(in_audio, file_name)): + continue + _, file_ext = os.path.splitext(file_name) + if file_ext.lower() == ".wav": + audio_file_path = os.path.join(in_audio, file_name) + audio_file_paths.append(audio_file_path) + audio_file_paths = sorted(audio_file_paths) + out_file_paths = [""] * len(audio_file_paths) + extract_features( + in_audios=audio_file_paths, + out_files=out_file_paths, + deepspeech_pb_path=deepspeech_pb_path, + metainfo_file_path=args.metainfo) + + +if __name__ == "__main__": + main() + diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_wav.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_wav.py new file mode 100644 index 00000000..8458c5f2 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/extract_wav.py @@ -0,0 +1,87 @@ +""" + Script for extracting audio (16-bit, mono, 22000 Hz) from video file. +""" + +import os +import argparse +import subprocess + + +def parse_args(): + """ + Create python script parameters. + + Returns + ------- + ArgumentParser + Resulted args. 
+ """ + parser = argparse.ArgumentParser( + description="Extract audio from video file", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--in-video", + type=str, + required=True, + help="path to input video file or directory") + parser.add_argument( + "--out-audio", + type=str, + help="path to output audio file") + + args = parser.parse_args() + return args + + +def extract_audio(in_video, + out_audio): + """ + Real extract audio from video file. + + Parameters + ---------- + in_video : str + Path to input video file. + out_audio : str + Path to output audio file. + """ + if not out_audio: + file_stem, _ = os.path.splitext(in_video) + out_audio = file_stem + ".wav" + # command1 = "ffmpeg -i {in_video} -vn -acodec copy {aac_audio}" + # command2 = "ffmpeg -i {aac_audio} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" + # command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 22000 {out_audio}" + command = "ffmpeg -i {in_video} -vn -acodec pcm_s16le -ac 1 -ar 16000 {out_audio}" + subprocess.call([command.format(in_video=in_video, out_audio=out_audio)], shell=True) + + +def main(): + """ + Main body of script. 
+ """ + args = parse_args() + in_video = os.path.expanduser(args.in_video) + if not os.path.exists(in_video): + raise Exception("Input file/directory doesn't exist: {}".format(in_video)) + if os.path.isfile(in_video): + extract_audio( + in_video=in_video, + out_audio=args.out_audio) + else: + video_file_paths = [] + for file_name in os.listdir(in_video): + if not os.path.isfile(os.path.join(in_video, file_name)): + continue + _, file_ext = os.path.splitext(file_name) + if file_ext.lower() in (".mp4", ".mkv", ".avi"): + video_file_path = os.path.join(in_video, file_name) + video_file_paths.append(video_file_path) + video_file_paths = sorted(video_file_paths) + for video_file_path in video_file_paths: + extract_audio( + in_video=video_file_path, + out_audio="") + + +if __name__ == "__main__": + main() diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/fea_win.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/fea_win.py new file mode 100644 index 00000000..df9e27b4 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/fea_win.py @@ -0,0 +1,11 @@ +import numpy as np + +net_output = np.load('french.ds.npy').reshape(-1, 29) +win_size = 16 +zero_pad = np.zeros((int(win_size / 2), net_output.shape[1])) +net_output = np.concatenate((zero_pad, net_output, zero_pad), axis=0) +windows = [] +for window_index in range(0, net_output.shape[0] - win_size, 2): + windows.append(net_output[window_index:window_index + win_size]) +print(np.array(windows).shape) +np.save('aud_french.npy', np.array(windows)) diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/load_audface.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/load_audface.py new file mode 100644 index 00000000..b6d59727 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/load_audface.py @@ 
-0,0 +1,114 @@ +import os +import torch +import numpy as np +import imageio +import json +import torch.nn.functional as F +import cv2 + + +def load_audface_data(basedir, testskip=1, test_file=None, aud_file=None, test_size=-1): + if test_file is not None: + with open(os.path.join(basedir, test_file)) as fp: + meta = json.load(fp) + poses = [] + auds = [] + aud_features = np.load(os.path.join(basedir, aud_file)) + cur_id = 0 + for frame in meta['frames'][::testskip]: + poses.append(np.array(frame['transform_matrix'])) + aud_id = cur_id + auds.append(aud_features[aud_id]) + cur_id = cur_id + 1 + if cur_id == aud_features.shape[0] or cur_id == test_size: + break + poses = np.array(poses).astype(np.float32) + auds = np.array(auds).astype(np.float32) + bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg')) + H, W = bc_img.shape[0], bc_img.shape[1] + focal, cx, cy = float(meta['focal_len']), float( + meta['cx']), float(meta['cy']) + return poses, auds, bc_img, [H, W, focal, cx, cy] + + splits = ['train', 'val'] + metas = {} + for s in splits: + with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp: + metas[s] = json.load(fp) + all_com_imgs = [] + all_poses = [] + all_auds = [] + all_sample_rects = [] + aud_features = np.load(os.path.join(basedir, 'aud.npy')) + counts = [0] + for s in splits: + meta = metas[s] + com_imgs = [] + poses = [] + auds = [] + sample_rects = [] + if s == 'train' or testskip == 0: + skip = 1 + else: + skip = testskip + + for frame in meta['frames'][::skip]: + filename = os.path.join(basedir, 'com_imgs', + str(frame['img_id']) + '.jpg') + com_imgs.append(filename) + poses.append(np.array(frame['transform_matrix'])) + auds.append( + aud_features[min(frame['aud_id'], aud_features.shape[0]-1)]) + sample_rects.append(np.array(frame['face_rect'], dtype=np.int32)) + com_imgs = np.array(com_imgs) + poses = np.array(poses).astype(np.float32) + auds = np.array(auds).astype(np.float32) + counts.append(counts[-1] + com_imgs.shape[0]) + 
all_com_imgs.append(com_imgs) + all_poses.append(poses) + all_auds.append(auds) + all_sample_rects.append(sample_rects) + i_split = [np.arange(counts[i], counts[i+1]) for i in range(len(splits))] + com_imgs = np.concatenate(all_com_imgs, 0) + poses = np.concatenate(all_poses, 0) + auds = np.concatenate(all_auds, 0) + sample_rects = np.concatenate(all_sample_rects, 0) + + bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg')) + + H, W = bc_img.shape[:2] + focal, cx, cy = float(meta['focal_len']), float( + meta['cx']), float(meta['cy']) + + return com_imgs, poses, auds, bc_img, [H, W, focal, cx, cy], \ + sample_rects, i_split + + +def load_test_data(basedir, aud_file, test_pose_file='transforms_train.json', + testskip=1, test_size=-1, aud_start=0): + with open(os.path.join(basedir, test_pose_file)) as fp: + meta = json.load(fp) + poses = [] + auds = [] + aud_features = np.load(aud_file) + aud_ids = [] + cur_id = 0 + for frame in meta['frames'][::testskip]: + poses.append(np.array(frame['transform_matrix'])) + auds.append( + aud_features[min(aud_start+cur_id, aud_features.shape[0]-1)]) + aud_ids.append(aud_start+cur_id) + cur_id = cur_id + 1 + if cur_id == test_size or cur_id == aud_features.shape[0]: + break + poses = np.array(poses).astype(np.float32) + auds = np.array(auds).astype(np.float32) + bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg')) + H, W = bc_img.shape[0], bc_img.shape[1] + focal, cx, cy = float(meta['focal_len']), float( + meta['cx']), float(meta['cy']) + + with open(os.path.join(basedir, 'transforms_train.json')) as fp: + meta_torso = json.load(fp) + torso_pose = np.array(meta_torso['frames'][0]['transform_matrix']) + return poses, auds, bc_img, [H, W, focal, cx, cy], aud_ids, torso_pose diff --git a/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf.py b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf.py new file mode 100644 index 00000000..b912f6c6 --- /dev/null 
+++ b/talkingface-toolkit-main/talkingface/model/nerf_based_talkingface/ADNeRFmaster/run_nerf.py @@ -0,0 +1,1136 @@ +from load_audface import load_audface_data, load_test_data +import os +import sys +import numpy as np +import imageio +import json +import random +import time +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm, trange +from natsort import natsorted +import cv2 + +from run_nerf_helpers import * + +device = torch.device('cuda', 0) +device_torso = torch.device('cuda', 0) +np.random.seed(0) +DEBUG = False + + +def rot_to_euler(R): + batch_size, _, _ = R.shape + e = torch.ones((batch_size, 3)).cuda() + + R00 = R[:, 0, 0] + R01 = R[:, 0, 1] + R02 = R[:, 0, 2] + R10 = R[:, 1, 0] + R11 = R[:, 1, 1] + R12 = R[:, 1, 2] + R20 = R[:, 2, 0] + R21 = R[:, 2, 1] + R22 = R[:, 2, 2] + e[:, 2] = torch.atan2(R00, -R01) + e[:, 1] = torch.asin(-R02) + e[:, 0] = torch.atan2(R22, R12) + return e + + +def pose_to_euler_trans(poses): + e = rot_to_euler(poses) + t = poses[:, :3, 3] + return torch.cat((e, t), dim=1) + + +def euler2rot(euler_angle): + batch_size = euler_angle.shape[0] + theta = euler_angle[:, 0].reshape(-1, 1, 1) + phi = euler_angle[:, 1].reshape(-1, 1, 1) + psi = euler_angle[:, 2].reshape(-1, 1, 1) + one = torch.ones((batch_size, 1, 1), dtype=torch.float32, + device=euler_angle.device) + zero = torch.zeros((batch_size, 1, 1), dtype=torch.float32, + device=euler_angle.device) + rot_x = torch.cat(( + torch.cat((one, zero, zero), 1), + torch.cat((zero, theta.cos(), theta.sin()), 1), + torch.cat((zero, -theta.sin(), theta.cos()), 1), + ), 2) + rot_y = torch.cat(( + torch.cat((phi.cos(), zero, -phi.sin()), 1), + torch.cat((zero, one, zero), 1), + torch.cat((phi.sin(), zero, phi.cos()), 1), + ), 2) + rot_z = torch.cat(( + torch.cat((psi.cos(), -psi.sin(), zero), 1), + torch.cat((psi.sin(), psi.cos(), zero), 1), + torch.cat((zero, zero, one), 1) + ), 2) + return torch.bmm(rot_x, torch.bmm(rot_y, rot_z)) + + +def batchify(fn, 
chunk): + """Constructs a version of 'fn' that applies to smaller batches. + """ + if chunk is None: + return fn + + def ret(inputs): + return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0) + return ret + + +def run_network(inputs, viewdirs, aud_para, fn, embed_fn, embeddirs_fn, netchunk=1024*64): + """Prepares inputs and applies network 'fn'. + """ + inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]]) + embedded = embed_fn(inputs_flat) + aud = aud_para.unsqueeze(0).expand(inputs_flat.shape[0], -1) + embedded = torch.cat((embedded, aud), -1) + if viewdirs is not None: + input_dirs = viewdirs[:, None].expand(inputs.shape) + input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]]) + embedded_dirs = embeddirs_fn(input_dirs_flat) + embedded = torch.cat([embedded, embedded_dirs], -1) + + outputs_flat = batchify(fn, netchunk)(embedded) + outputs = torch.reshape(outputs_flat, list( + inputs.shape[:-1]) + [outputs_flat.shape[-1]]) + return outputs + + +def batchify_rays(rays_flat, bc_rgb, aud_para, chunk=1024*32, **kwargs): + """Render rays in smaller minibatches to avoid OOM. 
+ """ + all_ret = {} + for i in range(0, rays_flat.shape[0], chunk): + ret = render_rays(rays_flat[i:i+chunk], bc_rgb[i:i+chunk], + aud_para, **kwargs) + for k in ret: + if k not in all_ret: + all_ret[k] = [] + all_ret[k].append(ret[k]) + + all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret} + return all_ret + + +def render_dynamic_face(H, W, focal, cx, cy, chunk=1024*32, rays=None, bc_rgb=None, aud_para=None, + c2w=None, ndc=True, near=0., far=1., + use_viewdirs=False, c2w_staticcam=None, + **kwargs): + if c2w is not None: + # special case to render full image + rays_o, rays_d = get_rays(H, W, focal, c2w, cx, cy, c2w.device) + bc_rgb = bc_rgb.reshape(-1, 3) + else: + # use provided ray batch + rays_o, rays_d = rays + + if use_viewdirs: + # provide ray directions as input + viewdirs = rays_d + if c2w_staticcam is not None: + # special case to visualize effect of viewdirs + rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam, cx, cy) + viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) + viewdirs = torch.reshape(viewdirs, [-1, 3]).float() + + sh = rays_d.shape # [..., 3] + if ndc: + # for forward facing scenes + rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d) + + # Create ray batch + rays_o = torch.reshape(rays_o, [-1, 3]).float() + rays_d = torch.reshape(rays_d, [-1, 3]).float() + + near, far = near * \ + torch.ones_like(rays_d[..., :1]), far * \ + torch.ones_like(rays_d[..., :1]) + rays = torch.cat([rays_o, rays_d, near, far], -1) + if use_viewdirs: + rays = torch.cat([rays, viewdirs], -1) + + # Render and reshape + all_ret = batchify_rays(rays, bc_rgb, aud_para, chunk, **kwargs) + for k in all_ret: + k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) + all_ret[k] = torch.reshape(all_ret[k], k_sh) + + k_extract = ['rgb_map', 'disp_map', 'acc_map', 'last_weight', 'rgb_map_fg'] + ret_list = [all_ret[k] for k in k_extract] + ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract} + return ret_list + [ret_dict] + + +def 
render(H, W, focal, cx, cy, chunk=1024*32, rays=None, c2w=None, ndc=True, + near=0., far=1., + use_viewdirs=False, c2w_staticcam=None, + **kwargs): + """Render rays + Args: + H: int. Height of image in pixels. + W: int. Width of image in pixels. + focal: float. Focal length of pinhole camera. + chunk: int. Maximum number of rays to process simultaneously. Used to + control maximum memory usage. Does not affect final results. + rays: array of shape [2, batch_size, 3]. Ray origin and direction for + each example in batch. + c2w: array of shape [3, 4]. Camera-to-world transformation matrix. + ndc: bool. If True, represent ray origin, direction in NDC coordinates. + near: float or array of shape [batch_size]. Nearest distance for a ray. + far: float or array of shape [batch_size]. Farthest distance for a ray. + use_viewdirs: bool. If True, use viewing direction of a point in space in model. + c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for + camera while using other c2w argument for viewing directions. + Returns: + rgb_map: [batch_size, 3]. Predicted RGB values for rays. + disp_map: [batch_size]. Disparity map. Inverse of depth. + acc_map: [batch_size]. Accumulated opacity (alpha) along a ray. + extras: dict with everything returned by render_rays(). 
+ """ + if c2w is not None: + # special case to render full image + rays_o, rays_d = get_rays(H, W, focal, c2w, cx, cy) + else: + # use provided ray batch + rays_o, rays_d = rays + + if use_viewdirs: + # provide ray directions as input + viewdirs = rays_d + if c2w_staticcam is not None: + # special case to visualize effect of viewdirs + rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam, cx, cy) + viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True) + viewdirs = torch.reshape(viewdirs, [-1, 3]).float() + + sh = rays_d.shape # [..., 3] + if ndc: + # for forward facing scenes + rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d) + + # Create ray batch + rays_o = torch.reshape(rays_o, [-1, 3]).float() + rays_d = torch.reshape(rays_d, [-1, 3]).float() + + near, far = near * \ + torch.ones_like(rays_d[..., :1]), far * \ + torch.ones_like(rays_d[..., :1]) + rays = torch.cat([rays_o, rays_d, near, far], -1) + if use_viewdirs: + rays = torch.cat([rays, viewdirs], -1) + + # Render and reshape + all_ret = batchify_rays(rays, chunk, **kwargs) + for k in all_ret: + k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:]) + all_ret[k] = torch.reshape(all_ret[k], k_sh) + + k_extract = ['rgb_map', 'disp_map', 'acc_map'] + ret_list = [all_ret[k] for k in k_extract] + ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract} + return ret_list + [ret_dict] + + +def render_path(render_poses, aud_paras, bc_img, hwfcxy, + chunk, render_kwargs, gt_imgs=None, savedir=None, render_factor=0): + + H, W, focal, cx, cy = hwfcxy + + if render_factor != 0: + # Render downsampled for speed + H = H//render_factor + W = W//render_factor + focal = focal/render_factor + + rgbs = [] + disps = [] + last_weights = [] + rgb_fgs = [] + + t = time.time() + for i, c2w in enumerate(tqdm(render_poses)): + print(i, time.time() - t) + t = time.time() + rgb, disp, acc, last_weight, rgb_fg, _ = render_dynamic_face( + H, W, focal, cx, cy, chunk=chunk, c2w=c2w[:3, + :4], 
aud_para=aud_paras[i], bc_rgb=bc_img, + **render_kwargs) + rgbs.append(rgb.cpu().numpy()) + disps.append(disp.cpu().numpy()) + last_weights.append(last_weight.cpu().numpy()) + rgb_fgs.append(rgb_fg.cpu().numpy()) + # if i == 0: + # print(rgb.shape, disp.shape) + + """ + if gt_imgs is not None and render_factor==0: + p = -10. * np.log10(np.mean(np.square(rgb.cpu().numpy() - gt_imgs[i]))) + print(p) + """ + + if savedir is not None: + rgb8 = to8b(rgbs[-1]) + filename = os.path.join(savedir, '{:03d}.png'.format(i)) + imageio.imwrite(filename, rgb8) + + rgbs = np.stack(rgbs, 0) + disps = np.stack(disps, 0) + last_weights = np.stack(last_weights, 0) + rgb_fgs = np.stack(rgb_fgs, 0) + + return rgbs, disps, last_weights, rgb_fgs + + +def create_nerf(args, ext, dim_aud, device_spec=torch.device('cuda', 0), with_audatt=False): + """Instantiate NeRF's MLP model. + """ + embed_fn, input_ch = get_embedder( + args.multires, args.i_embed, device=device_spec) + + input_ch_views = 0 + embeddirs_fn = None + if args.use_viewdirs: + embeddirs_fn, input_ch_views = get_embedder( + args.multires_views, args.i_embed, device=device_spec) + output_ch = 5 if args.N_importance > 0 else 4 + skips = [4] + model = FaceNeRF(D=args.netdepth, W=args.netwidth, + input_ch=input_ch, dim_aud=dim_aud, + output_ch=output_ch, skips=skips, + input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device_spec) + grad_vars = list(model.parameters()) + + model_fine = None + if args.N_importance > 0: + model_fine = FaceNeRF(D=args.netdepth_fine, W=args.netwidth_fine, + input_ch=input_ch, dim_aud=dim_aud, + output_ch=output_ch, skips=skips, + input_ch_views=input_ch_views, use_viewdirs=args.use_viewdirs).to(device_spec) + grad_vars += list(model_fine.parameters()) + + def network_query_fn(inputs, viewdirs, aud_para, network_fn): \ + return run_network(inputs, viewdirs, aud_para, network_fn, + embed_fn=embed_fn, embeddirs_fn=embeddirs_fn, netchunk=args.netchunk) + + # Create optimizer + optimizer = 
torch.optim.Adam( + params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)) + + start = 0 + basedir = args.basedir + expname = args.expname + + ########################## + + # Load checkpoints + if args.ft_path is not None and args.ft_path != 'None': + ckpts = [args.ft_path] + else: + ckpts = [os.path.join(basedir, expname, f) for f in natsorted( + os.listdir(os.path.join(basedir, expname))) if ext in f] + + print('Found ckpts', ckpts) + learned_codes_dict = None + AudNet_state = None + optimizer_aud_state = None + AudAttNet_state = None + if len(ckpts) > 0 and not args.no_reload: + ckpt_path = ckpts[-1] + print('Reloading from', ckpt_path) + ckpt = torch.load(ckpt_path, map_location=device) + + start = ckpt['global_step'] + optimizer.load_state_dict(ckpt['optimizer_state_dict']) + AudNet_state = ckpt['network_audnet_state_dict'] + optimizer_aud_state = ckpt['optimizer_aud_state_dict'] + if with_audatt: + AudAttNet_state = ckpt['network_audattnet_state_dict'] + + # Load model + model.load_state_dict(ckpt['network_fn_state_dict']) + if model_fine is not None: + model_fine.load_state_dict(ckpt['network_fine_state_dict']) + + ########################## + + render_kwargs_train = { + 'network_query_fn': network_query_fn, + 'perturb': args.perturb, + 'N_importance': args.N_importance, + 'network_fine': model_fine, + 'N_samples': args.N_samples, + 'network_fn': model, + 'use_viewdirs': args.use_viewdirs, + 'white_bkgd': args.white_bkgd, + 'raw_noise_std': args.raw_noise_std, + } + + # NDC only good for LLFF-style forward facing data + if args.dataset_type != 'llff' or args.no_ndc: + print('Not ndc!') + render_kwargs_train['ndc'] = False + render_kwargs_train['lindisp'] = args.lindisp + + render_kwargs_test = { + k: render_kwargs_train[k] for k in render_kwargs_train} + render_kwargs_test['perturb'] = False + render_kwargs_test['raw_noise_std'] = 0. 
def raw2outputs(raw, z_vals, rays_d, bc_rgb, raw_noise_std=0, white_bkgd=False, pytest=False):
    """Transforms model's predictions to semantically meaningful values.

    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
        bc_rgb: [num_rays, 3]. Background color; replaces the final sample's
            color so the last (far) sample composites the background in.
        raw_noise_std: std of Gaussian noise added to raw sigma (regularizer).
        white_bkgd: if True, composite the residual transmittance onto white.
        pytest: if True, replace the noise with numpy's seeded randoms for
            deterministic testing.
    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disp_map: [num_rays]. Disparity map. Inverse of depth map.
        acc_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
        depth_map: [num_rays]. Estimated distance to object.
        rgb_map_fg: [num_rays, 3]. Foreground-only color (background sample excluded).
    """
    def raw2alpha(raw, dists, act_fn=F.relu): return 1. - \
        torch.exp(-(act_fn(raw)+1e-6)*dists)

    dists = z_vals[..., 1:] - z_vals[..., :-1]
    # BUGFIX: torch.tensor, not the legacy torch.Tensor constructor — the
    # legacy constructor rejects a CUDA device passed via the `device` kwarg.
    dists = torch.cat([dists, torch.tensor([1e10], device=z_vals.device).expand(
        dists[..., :1].shape)], -1)  # [N_rays, N_samples]

    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    rgb = torch.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]
    # Force the last sample along each ray to the background color.
    rgb = torch.cat((rgb[:, :-1, :], bc_rgb.unsqueeze(1)), dim=1)
    noise = 0.
    if raw_noise_std > 0.:
        # BUGFIX: allocate the noise on raw's device instead of relying on the
        # process-wide default tensor type being CUDA.
        noise = torch.randn(raw[..., 3].shape, device=raw.device) * raw_noise_std

        # Overwrite randomly sampled data if pytest
        if pytest:
            np.random.seed(0)
            noise = np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std
            noise = torch.as_tensor(noise, dtype=raw.dtype, device=raw.device)

    alpha = raw2alpha(raw[..., 3] + noise, dists)  # [N_rays, N_samples]
    # weight_i = alpha_i * prod_{j<i} (1 - alpha_j)   (transmittance product)
    weights = alpha * \
        torch.cumprod(
            torch.cat([torch.ones((alpha.shape[0], 1), device=alpha.device), 1.-alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = torch.sum(weights[..., None] * rgb, -2)  # [N_rays, 3]

    # Foreground color: drop the final (background) sample from the sum.
    rgb_map_fg = torch.sum(weights[:, :-1, None]*rgb[:, :-1, :], -2)

    depth_map = torch.sum(weights * z_vals, -1)
    disp_map = 1./torch.max(1e-10 * torch.ones_like(depth_map),
                            depth_map / torch.sum(weights, -1))
    acc_map = torch.sum(weights, -1)

    if white_bkgd:
        rgb_map = rgb_map + (1.-acc_map[..., None])

    return rgb_map, disp_map, acc_map, weights, depth_map, rgb_map_fg
def render_rays(ray_batch,
                bc_rgb,
                aud_para,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):
    """Volumetric rendering.
    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin, ray direction, min
        dist, max dist, and unit-magnitude viewing direction.
      bc_rgb: per-ray background color, forwarded to raw2outputs.
      aud_para: audio conditioning vector, forwarded to the query function.
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      retraw: bool. If True, include model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      white_bkgd: bool. If True, assume a white background.
      raw_noise_std: ...
      verbose: bool. If True, print more debugging info.
    Returns:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from model.
      rgb0: See rgb_map. Output for coarse model.
      disp0: See disp_map. Output for coarse model.
      acc0: See acc_map. Output for coarse model.
      z_std: [num_rays]. Standard deviation of distances along ray for each
        sample.
    """
    N_rays = ray_batch.shape[0]
    # Packed ray layout: [origin(3), direction(3), near, far, (viewdir(3))].
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
    bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]  # [-1,1]

    # Coarse sample positions: linear in depth, or linear in disparity.
    t_vals = torch.linspace(0., 1., steps=N_samples, device=rays_o.device)
    if not lindisp:
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))

    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        # get intervals between samples
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape, device=rays_o.device)

        # Pytest, overwrite u with numpy's fixed random numbers
        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand).to(rays_o.device)
        # Pin the final sample to the far plane so the background sample
        # (replaced with bc_rgb in raw2outputs) stays at the ray's end.
        t_rand[..., -1] = 1.0
        z_vals = lower + (upper - lower) * t_rand
    pts = rays_o[..., None, :] + rays_d[..., None, :] * \
        z_vals[..., :, None]  # [N_rays, N_samples, 3]


#     raw = run_network(pts)
    raw = network_query_fn(pts, viewdirs, aud_para, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map, rgb_map_fg = raw2outputs(
        raw, z_vals, rays_d, bc_rgb, raw_noise_std, white_bkgd, pytest=pytest)

    if N_importance > 0:
        # Keep the coarse outputs before they are overwritten by the fine pass.
        rgb_map_0, disp_map_0, acc_map_0, last_weight_0, rgb_map_fg_0 = \
            rgb_map, disp_map, acc_map, weights[..., -1], rgb_map_fg

        # Hierarchical sampling: draw N_importance extra samples from the
        # coarse weight distribution (sample_pdf defined elsewhere in file).
        z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        z_samples = sample_pdf(
            z_vals_mid, weights[..., 1:-1], N_importance, det=(perturb == 0.), pytest=pytest)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[..., None, :] + rays_d[..., None, :] * \
            z_vals[..., :, None]

        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, viewdirs, aud_para, run_fn)

        rgb_map, disp_map, acc_map, weights, depth_map, rgb_map_fg = raw2outputs(
            raw, z_vals, rays_d, bc_rgb, raw_noise_std, white_bkgd, pytest=pytest)

    ret = {'rgb_map': rgb_map, 'disp_map': disp_map,
           'acc_map': acc_map, 'rgb_map_fg': rgb_map_fg}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]
        # last_weight = weight of the background sample; used by the caller to
        # composite head and torso renders.
        ret['last_weight'] = weights[..., -1]
        ret['last_weight0'] = last_weight_0
        ret['rgb_map_fg0'] = rgb_map_fg_0

    # DEBUG is a module-level flag defined elsewhere in this file.
    for k in ret:
        if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
            print(f"! [Numerical Error] {k} contains nan or inf.")

    return ret
def config_parser():
    """Build the configargparse parser holding every AD-NeRF run option.

    Returns the parser (not parsed args); `--config` lets all options be
    supplied from a config file.
    """
    import configargparse
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    parser.add_argument("--basedir", type=str, default='./logs/',
                        help='where to store ckpts and logs')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')

    # training options
    parser.add_argument("--netdepth", type=int, default=8,
                        help='layers in network')
    parser.add_argument("--netwidth", type=int, default=256,
                        help='channels per layer')
    parser.add_argument("--netdepth_fine", type=int, default=8,
                        help='layers in fine network')
    parser.add_argument("--netwidth_fine", type=int, default=256,
                        help='channels per layer in fine network')
    parser.add_argument("--N_rand", type=int, default=1024,
                        help='batch size (number of random rays per gradient step)')
    parser.add_argument("--lrate", type=float, default=5e-4,
                        help='learning rate')
    parser.add_argument("--lrate_decay", type=int, default=250,
                        help='exponential learning rate decay (in 1000 steps)')
    parser.add_argument("--chunk", type=int, default=1024,
                        help='number of rays processed in parallel, decrease if running out of memory')
    parser.add_argument("--netchunk", type=int, default=1024*64,
                        help='number of pts sent through network in parallel, decrease if running out of memory')
    # NOTE(review): store_false means no_batching defaults to True, i.e.
    # batching is OFF unless the flag is passed — opposite of vanilla NeRF's
    # store_true. Looks intentional for AD-NeRF; confirm before changing.
    parser.add_argument("--no_batching", action='store_false',
                        help='only take random rays from 1 image at a time')
    parser.add_argument("--no_reload", action='store_true',
                        help='do not reload weights from saved ckpt')
    parser.add_argument("--ft_path", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')
    parser.add_argument("--N_iters", type=int, default=400000,
                        help='number of iterations')

    # rendering options
    parser.add_argument("--N_samples", type=int, default=64,
                        help='number of coarse samples per ray')
    parser.add_argument("--N_importance", type=int, default=128,
                        help='number of additional fine samples per ray')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    # NOTE(review): store_false — view directions default to ON here.
    parser.add_argument("--use_viewdirs", action='store_false',
                        help='use full 5D input instead of 3D')
    parser.add_argument("--i_embed", type=int, default=0,
                        help='set 0 for default positional encoding, -1 for none')
    parser.add_argument("--multires", type=int, default=10,
                        help='log2 of max freq for positional encoding (3D location)')
    parser.add_argument("--multires_views", type=int, default=4,
                        help='log2 of max freq for positional encoding (2D direction)')
    parser.add_argument("--raw_noise_std", type=float, default=0.,
                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')

    parser.add_argument("--render_only", action='store_true',
                        help='do not optimize, reload weights and render out render_poses path')
    parser.add_argument("--render_test", action='store_true',
                        help='render the test set instead of render_poses path')
    parser.add_argument("--render_factor", type=int, default=0,
                        help='downsampling factor to speed up rendering, set 4 or 8 for fast preview')

    # training options
    parser.add_argument("--precrop_iters", type=int, default=0,
                        help='number of steps to train on central crops')
    parser.add_argument("--precrop_frac", type=float,
                        default=.5, help='fraction of img taken for central crops')

    # dataset options
    parser.add_argument("--dataset_type", type=str, default='audface',
                        help='options: llff / blender / deepvoxels')
    parser.add_argument("--testskip", type=int, default=1,
                        help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels')

    # deepvoxels flags
    parser.add_argument("--shape", type=str, default='greek',
                        help='options : armchair / cube / greek / vase')

    # blender flags
    # NOTE(review): store_false — white background defaults to ON.
    parser.add_argument("--white_bkgd", action='store_false',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')
    parser.add_argument("--half_res", action='store_true',
                        help='load blender synthetic data at 400x400 instead of 800x800')

    # face flags
    parser.add_argument("--with_test", type=int, default=0,
                        help='whether to test')
    parser.add_argument("--dim_aud", type=int, default=64,
                        help='dimension of audio features for NeRF')
    parser.add_argument("--dim_aud_body", type=int, default=64,
                        help='dimension of audio features for NeRF')
    parser.add_argument("--sample_rate", type=float, default=0.95,
                        help="sample rate in a bounding box")
    parser.add_argument("--near", type=float, default=0.3,
                        help="near sampling plane")
    parser.add_argument("--far", type=float, default=0.9,
                        help="far sampling plane")
    parser.add_argument("--test_pose_file", type=str, default='transforms_train.json',
                        help='test pose file')
    parser.add_argument("--aud_file", type=str, default='aud.npy',
                        help='test audio deepspeech file')
    parser.add_argument("--win_size", type=int, default=16,
                        help="windows size of audio feature")
    parser.add_argument("--smo_size", type=int, default=8,
                        help="window size for smoothing audio features")
    parser.add_argument('--test_size', type=int, default=-1,
                        help='test size')
    parser.add_argument('--aud_start', type=int, default=0,
                        help='test audio start pos')
    parser.add_argument('--test_save_folder', type=str, default='test_aud_rst',
                        help='folder to store test result')

    # llff flags
    parser.add_argument("--factor", type=int, default=8,
                        help='downsample factor for LLFF images')
    parser.add_argument("--no_ndc", action='store_true',
                        help='do not use normalized device coordinates (set for non-forward facing scenes)')
    parser.add_argument("--lindisp", action='store_true',
                        help='sampling linearly in disparity rather than depth')
    parser.add_argument("--spherify", action='store_true',
                        help='set for spherical 360 scenes')
    parser.add_argument("--llffhold", type=int, default=8,
                        help='will take every 1/N images as LLFF test set, paper uses 8')

    # logging/saving options
    # BUGFIX: help-text typo "loggin" -> "logging".
    parser.add_argument("--i_print", type=int, default=100,
                        help='frequency of console printout and metric logging')
    parser.add_argument("--i_img", type=int, default=500,
                        help='frequency of tensorboard image logging')
    parser.add_argument("--i_weights", type=int, default=10000,
                        help='frequency of weight ckpt saving')
    parser.add_argument("--i_testset", type=int, default=10000,
                        help='frequency of testset saving')
    parser.add_argument("--i_video", type=int, default=50000,
                        help='frequency of render_poses video saving')

    return parser
def train():
    """Entry point: AD-NeRF training / test rendering for head + torso NeRFs.

    Parses options, loads the audio-conditioned face dataset, builds two NeRF
    model sets ('head.tar' on the global `device`, 'body.tar' on
    `device_torso`), then either renders a test audio sequence (--with_test)
    or runs the ray-sampling training loop with periodic checkpoint and
    test-set dumps.  Relies on module-level globals (`device`, `device_torso`)
    and helpers defined elsewhere in this file (load_test_data,
    load_audface_data, render_path, render_dynamic_face, get_rays,
    get_rays_np, pose_to_euler_trans, ...).
    """
    parser = config_parser()
    args = parser.parse_args()

    # Load data
    if args.with_test == 1:
        poses, auds, bc_img, hwfcxy, aud_ids, torso_pose = \
            load_test_data(args.datadir, args.aud_file,
                           args.test_pose_file, args.testskip, args.test_size, args.aud_start)
        torso_pose = torch.as_tensor(torso_pose).to(device_torso).float()
        com_images = np.zeros(1)
    else:
        com_images, poses, auds, bc_img, hwfcxy, sample_rects, \
            i_split = load_audface_data(args.datadir, args.testskip)

    if args.with_test == 0:
        i_train, i_val = i_split

    near = args.near
    far = args.far

    # Cast intrinsics to right types
    H, W, focal, cx, cy = hwfcxy
    H, W = int(H), int(W)
    hwf = [H, W, focal]
    hwfcxy = [H, W, focal, cx, cy]

    # Create log dir and copy the config file
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            file.write(open(args.config, 'r').read())

    # Create nerf model (head, on the primary device, with AudAttNet state)
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, \
        learned_codes, AudNet_state, optimizer_aud_state, AudAttNet_state = create_nerf(
            args, 'head.tar', args.dim_aud, device, True)
    global_step = start

    AudNet = AudioNet(args.dim_aud, args.win_size).to(device)
    AudAttNet = AudioAttNet().to(device)
    optimizer_Aud = torch.optim.Adam(
        params=list(AudNet.parameters()), lr=args.lrate, betas=(0.9, 0.999))

    if AudNet_state is not None:
        AudNet.load_state_dict(AudNet_state)
    if AudAttNet_state is not None:
        print('load audattnet')
        AudAttNet.load_state_dict(AudAttNet_state)
    if optimizer_aud_state is not None:
        optimizer_Aud.load_state_dict(optimizer_aud_state)
    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move training data to GPU
    bc_img = torch.Tensor(bc_img).to(device).float()/255.0
    poses = torch.Tensor(poses).to(device).float()
    auds = torch.Tensor(auds).to(device).float()

    num_frames = com_images.shape[0]

    # Torso conditioning = truncated audio feature + positional-encoded
    # head-pose euler/translation.
    embed_fn, input_ch = get_embedder(3, 0)
    dim_torso_signal = args.dim_aud_body + 2*input_ch
    # Create torso nerf model (on the secondary device)
    render_kwargs_train_torso, render_kwargs_test_torso, start, grad_vars_torso, optimizer_torso, \
        learned_codes_torso, AudNet_state_torso, optimizer_aud_state_torso, _ = create_nerf(
            args, 'body.tar', dim_torso_signal, device_torso)
    global_step = start

    AudNet_torso = AudioNet(args.dim_aud_body, args.win_size).to(device_torso)
    optimizer_Aud_torso = torch.optim.Adam(
        params=list(AudNet_torso.parameters()), lr=args.lrate, betas=(0.9, 0.999))

    if AudNet_state_torso is not None:
        AudNet_torso.load_state_dict(AudNet_state_torso)
    if optimizer_aud_state_torso is not None:
        optimizer_Aud_torso.load_state_dict(optimizer_aud_state_torso)
    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train_torso.update(bds_dict)
    render_kwargs_test_torso.update(bds_dict)

    if args.with_test:
        # Test-only path: render every pose with attention-smoothed audio,
        # composite head over torso, write frames + an AVI, then return.
        print('RENDER ONLY')
        with torch.no_grad():
            testsavedir = os.path.join(basedir, expname, args.test_save_folder)
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses.shape)
            smo_half_win = int(args.smo_size / 2)
            auds_val = []
            # Smooth each frame's audio feature over a window of neighbours,
            # zero-padding at the sequence edges.
            for i in range(poses.shape[0]):
                left_i = i - smo_half_win
                right_i = i + smo_half_win
                pad_left, pad_right = 0, 0
                if left_i < 0:
                    pad_left = -left_i
                    left_i = 0
                if right_i > poses.shape[0]:
                    pad_right = right_i - poses.shape[0]
                    right_i = poses.shape[0]
                auds_win = auds[left_i:right_i]
                if pad_left > 0:
                    auds_win = torch.cat(
                        (torch.zeros_like(auds_win)[:pad_left], auds_win), dim=0)
                if pad_right > 0:
                    auds_win = torch.cat(
                        (auds_win, torch.zeros_like(auds_win)[:pad_right]), dim=0)
                auds_win = AudNet(auds_win)
                aud_smo = AudAttNet(auds_win)
                auds_val.append(aud_smo)
            auds_val = torch.stack(auds_val, 0)

            adjust_poses = poses.clone()
            adjust_poses_torso = poses.clone()

            et = pose_to_euler_trans(adjust_poses_torso)
            embed_et = torch.cat(
                (embed_fn(et[:, :3]), embed_fn(et[:, 3:])), dim=-1).to(device_torso)
            signal = torch.cat((auds_val[..., :args.dim_aud_body].to(
                device_torso), embed_et.squeeze()), dim=-1)
            t_start = time.time()
            vid_out = cv2.VideoWriter(os.path.join(testsavedir, 'result.avi'),
                                      cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 25, (W, H))
            for j in range(poses.shape[0]):
                rgbs, disps, last_weights, rgb_fgs = \
                    render_path(adjust_poses[j:j+1], auds_val[j:j+1],
                                bc_img, hwfcxy, args.chunk, render_kwargs_test)
                # Torso is always rendered from the single reference pose.
                rgbs_torso, disps_torso, last_weights_torso, rgb_fgs_torso = \
                    render_path(torso_pose.unsqueeze(
                        0), signal[j:j+1], bc_img.to(device_torso), hwfcxy, args.chunk, render_kwargs_test_torso)
                # Composite: head color attenuated by torso transmittance,
                # plus torso foreground.
                rgbs_com = rgbs*last_weights_torso[..., None] + rgb_fgs_torso
                rgb8 = to8b(rgbs_com[0])
                # OpenCV expects BGR, hence the channel reversal.
                vid_out.write(rgb8[:, :, ::-1])
                filename = os.path.join(
                    testsavedir, str(aud_ids[j]) + '.jpg')
                imageio.imwrite(filename, rgb8)
                print('finished render', j)
            print('finished render in', time.time()-t_start)
            vid_out.release()
        return

    N_rand = args.N_rand
    use_batching = not args.no_batching
    if use_batching:
        # For random ray batching
        print('get rays')
        rays = np.stack([get_rays_np(H, W, focal, p, cx, cy)
                         for p in poses[:, :3, :4]], 0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.concatenate([rays, com_images[:, None]], 1)
        # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = np.stack([rays_rgb[i]
                             for i in i_train], 0)  # train images only
        # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)

        print('done')
        i_batch = 0

    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = args.N_iters + 1
    print('Begin')
    print('TRAIN views are', i_train)
    print('VAL views are', i_val)

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch+N_rand]  # [B, 2+1, 3*?]
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

        else:
            # Random from one image
            img_i = np.random.choice(i_train)
            target_com = torch.as_tensor(imageio.imread(
                com_images[img_i])).to(device).float()/255.0
            pose = poses[img_i, :3, :4]
            # Torso always uses the first (reference) pose.
            pose_torso = poses[0, :3, :4].to(device_torso)
            rect = sample_rects[img_i]
            aud = auds[img_i]

            # Attention-smooth the audio feature over neighbouring frames,
            # zero-padding at the sequence edges (same scheme as test path).
            smo_half_win = int(args.smo_size/2)
            left_i = img_i - smo_half_win
            right_i = img_i + smo_half_win
            pad_left, pad_right = 0, 0
            if left_i < 0:
                pad_left = -left_i
                left_i = 0
            if right_i > i_train.shape[0]:
                pad_right = right_i-i_train.shape[0]
                right_i = i_train.shape[0]
            auds_win = auds[left_i:right_i]
            if pad_left > 0:
                auds_win = torch.cat(
                    (torch.zeros_like(auds_win)[:pad_left], auds_win), dim=0)
            if pad_right > 0:
                auds_win = torch.cat(
                    (auds_win, torch.zeros_like(auds_win)[:pad_right]), dim=0)
            auds_win = AudNet(auds_win)
            aud_smo = AudAttNet(auds_win)
            aud_smo_torso = aud_smo.to(device_torso)[..., :args.dim_aud_body]

            et = pose_to_euler_trans(poses[img_i].unsqueeze(0))
            embed_et = torch.cat(
                (embed_fn(et[:, :3]), embed_fn(et[:, 3:])), dim=1).to(device_torso)
            signal = torch.cat((aud_smo_torso, embed_et.squeeze()), dim=-1)
            if N_rand is not None:
                rays_o, rays_d = get_rays(
                    H, W, focal, pose, cx, cy, device)  # (H, W, 3), (H, W, 3)
                rays_o_torso, rays_d_torso = get_rays(
                    H, W, focal, pose_torso, cx, cy, device_torso)

                if i < args.precrop_iters:
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
                            torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
                        ), -1)
                    if i == start:
                        print(
                            f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
                else:
                    coords = torch.stack(torch.meshgrid(torch.linspace(
                        0, H-1, H), torch.linspace(0, W-1, W)), -1)  # (H, W, 2)

                coords = torch.reshape(coords, [-1, 2])  # (H * W, 2)
                if args.sample_rate > 0:
                    # Oversample rays inside the lower half of the image
                    # (rect overrides the per-image sample rect here).
                    rect = [0, H/2, W, H/2]
                    rect_inds = (coords[:, 0] >= rect[0]) & (
                        coords[:, 0] <= rect[0] + rect[2]) & (
                            coords[:, 1] >= rect[1]) & (
                                coords[:, 1] <= rect[1] + rect[3])
                    coords_rect = coords[rect_inds]
                    coords_norect = coords[~rect_inds]
                    rect_num = int(N_rand*float(rect[2])*rect[3]/H/W)
                    norect_num = N_rand - rect_num
                    select_inds_rect = np.random.choice(
                        coords_rect.shape[0], size=[rect_num], replace=False)  # (N_rand,)
                    # (N_rand, 2)
                    select_coords_rect = coords_rect[select_inds_rect].long()
                    select_inds_norect = np.random.choice(
                        coords_norect.shape[0], size=[norect_num], replace=False)  # (N_rand,)
                    # (N_rand, 2)
                    select_coords_norect = coords_norect[select_inds_norect].long(
                    )
                    select_coords = torch.cat(
                        (select_coords_norect, select_coords_rect), dim=0)

                else:
                    select_inds = np.random.choice(
                        coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                    select_coords = coords[select_inds].long()
                    norect_num = 0

                rays_o = rays_o[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                bc_rgb = bc_img[select_coords[:, 0],
                                select_coords[:, 1]]

                rays_o_torso = rays_o_torso[select_coords[:, 0],
                                            select_coords[:, 1]]  # (N_rand, 3)
                rays_d_torso = rays_d_torso[select_coords[:, 0],
                                            select_coords[:, 1]]  # (N_rand, 3)
                batch_rays_torso = torch.stack([rays_o_torso, rays_d_torso], 0)
                bc_rgb = bc_img[select_coords[:, 0],
                                select_coords[:, 1]]
                bc_rgb_torso = bc_rgb.to(device_torso)

                target_s_com = target_com[select_coords[:, 0],
                                          select_coords[:, 1]]  # (N_rand, 3)

        #####  Core optimization loop  #####
        rgb, disp, acc, last_weight, rgb_fg, extras = \
            render_dynamic_face(H, W, focal, cx, cy, chunk=args.chunk, rays=batch_rays,
                                aud_para=aud_smo, bc_rgb=bc_rgb,
                                verbose=i < 10, retraw=True,
                                **render_kwargs_train)
        rgb_torso, disp_torso, acc_torso, last_weight_torso, rgb_fg_torso, extras_torso = \
            render_dynamic_face(H, W, focal, cx, cy, chunk=args.chunk, rays=batch_rays_torso,
                                aud_para=signal, bc_rgb=bc_rgb_torso,
                                verbose=i < 10, retraw=True,
                                **render_kwargs_train_torso)
        # Composite head over torso: head color attenuated by torso background
        # transmittance, plus torso foreground.
        rgb_com = rgb * \
            last_weight_torso.to(device)[..., None] + rgb_fg_torso.to(device)

        # NOTE(review): only optimizer_torso is stepped here, though the
        # composite loss also depends on the head nets; the head optimizers
        # only get their LR updated below — confirm this is intentional.
        optimizer_torso.zero_grad()
        img_loss_com = img2mse(rgb_com, target_s_com)
        trans = extras['raw'][..., -1]
        split_weight = float(1.0)
        loss = img_loss_com
        psnr = mse2psnr(img_loss_com)

        if 'rgb0' in extras_torso:
            # Same composite loss for the coarse-network outputs.
            rgb_com0 = extras['rgb0'] * \
                extras_torso['last_weight0'].to(
                    device)[..., None] + extras_torso['rgb_map_fg0'].to(device)
            img_loss0 = img2mse(rgb_com0, target_s_com)
            loss = loss + img_loss0

        loss.backward()
        optimizer_torso.step()

        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        #print('cur_rate', new_lrate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate

        for param_group in optimizer_Aud.param_groups:
            param_group['lr'] = new_lrate

        for param_group in optimizer_torso.param_groups:
            param_group['lr'] = new_lrate

        for param_group in optimizer_Aud_torso.param_groups:
            param_group['lr'] = new_lrate
        ################################

        dt = time.time()-time0

        # Rest is logging
        if i % args.i_weights == 0:
            path = os.path.join(basedir, expname, '{:06d}_head.tar'.format(i))
            torch.save({
                'global_step': global_step,
                'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
                'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
                'network_audnet_state_dict': AudNet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'optimizer_aud_state_dict': optimizer_Aud.state_dict(),
                'network_audattnet_state_dict': AudAttNet.state_dict(),
            }, path)

            path = os.path.join(basedir, expname, '{:06d}_body.tar'.format(i))
            torch.save({
                'global_step': global_step,
                'network_fn_state_dict': render_kwargs_train_torso['network_fn'].state_dict(),
                'network_fine_state_dict': render_kwargs_train_torso['network_fine'].state_dict(),
                'network_audnet_state_dict': AudNet_torso.state_dict(),
                'optimizer_state_dict': optimizer_torso.state_dict(),
                'optimizer_aud_state_dict': optimizer_Aud_torso.state_dict(),
            }, path)
            print('Saved checkpoints at', path)

        if i % args.i_testset == 0 and i > 0:
            testsavedir = os.path.join(
                basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_val].shape)

            aud_torso = AudNet(
                auds[i_val])[..., :args.dim_aud_body].to(device_torso)
            et = pose_to_euler_trans(poses[i_val])
            embed_et = torch.cat(
                (embed_fn(et[:, :3]), embed_fn(et[:, 3:])), dim=1).to(device_torso)
            signal = torch.cat((aud_torso, embed_et.squeeze()), dim=-1)

            auds_val = AudNet(auds[i_val])
            with torch.no_grad():
                for j in range(auds_val.shape[0]):
                    rgbs, disps, last_weights, rgb_fgs = \
                        render_path(poses[i_val][j:j+1], auds_val[j:j+1],
                                    bc_img, hwfcxy, args.chunk, render_kwargs_test)
                    rgbs_torso, disps_torso, last_weights_torso, rgb_fgs_torso = \
                        render_path(poses[0].to(device_torso).unsqueeze(0),
                                    signal[j:j+1], bc_img.to(
                                        device_torso), hwfcxy, args.chunk, render_kwargs_test_torso)
                    rgbs_com = rgbs * \
                        last_weights_torso[..., None] + rgb_fgs_torso
                    rgb8 = to8b(rgbs_com[0])
                    filename = os.path.join(
                        testsavedir, '{:03d}.jpg'.format(j))
                    imageio.imwrite(filename, rgb8)
            print('Saved test set')

        if i % args.i_print == 0:
            tqdm.write(
                f"[TRAIN] Iter: {i} Loss: {img_loss_com.item()} PSNR: {psnr.item()}")

        global_step += 1


if __name__ == '__main__':
    # All new tensors default to CUDA floats; much of the code above relies
    # on this instead of passing device= explicitly.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    train()


# ---------------------------------------------------------------------------
# (The patch continues with a second file here: run_nerf_helpers.py)
# ---------------------------------------------------------------------------
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
import torch
torch.autograd.set_detect_anomaly(True)

# TODO: remove this dependency


# Misc
def img2mse(x, y, num=0):
    # If num > 0, compare only the first `num` rows (sub-batch MSE);
    # otherwise mean squared error over everything.
    if num > 0:
        return torch.mean((x[num] - y[num]) ** 2)
    else:
        return torch.mean((x - y) ** 2)


# PSNR from an MSE value: -10 * log10(mse).
def mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
def to8b(x):
    """Convert a float image in [0, 1] to uint8 in [0, 255]."""
    clipped = np.clip(x, 0, 1)
    return (255 * clipped).astype(np.uint8)


# Positional encoding (section 5.1)
class Embedder:
    """Positional encoder mapping each coordinate x to
    [x, sin(f1*x), cos(f1*x), sin(f2*x), cos(f2*x), ...]."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        cfg = self.kwargs
        dims = cfg['input_dims']
        dev = cfg['device']
        fns = []
        total_dim = 0

        # Optionally pass the raw input through unchanged.
        if cfg['include_input']:
            fns.append(lambda x: x)
            total_dim += dims

        top = cfg['max_freq_log2']
        count = cfg['num_freqs']
        if cfg['log_sampling']:
            bands = 2. ** torch.linspace(0., top, steps=count, device=dev)
        else:
            bands = torch.linspace(2. ** 0., 2. ** top, steps=count, device=dev)

        # One entry per (frequency, periodic fn) pair; defaults bind the
        # loop variables so every lambda keeps its own freq/fn.
        for band in bands:
            for periodic in cfg['periodic_fns']:
                fns.append(lambda x, p_fn=periodic, freq=band: p_fn(x * freq))
                total_dim += dims

        self.embed_fns = fns
        self.out_dim = total_dim

    def embed(self, inputs):
        """Apply every embedding function and concatenate on the last axis."""
        parts = [fn(inputs) for fn in self.embed_fns]
        return torch.cat(parts, -1)


def get_embedder(multires, i=0, device=torch.device('cuda', 0)):
    """Build a positional-encoding function for 3-D inputs.

    Args:
        multires: number of frequency octaves (max freq is 2**(multires-1)).
        i: encoding selector; -1 disables encoding (identity, dim 3).
        device: device the frequency bands live on.

    Returns:
        (embed_fn, out_dim) — a callable and its output dimensionality.
    """
    if i == -1:
        return nn.Identity(), 3

    embedder_obj = Embedder(
        include_input=True,
        input_dims=3,
        max_freq_log2=multires - 1,
        num_freqs=multires,
        log_sampling=True,
        device=device,
        periodic_fns=[torch.sin, torch.cos],
    )

    def embed(x, eo=embedder_obj):
        return eo.embed(x)

    return embed, embedder_obj.out_dim
# Audio feature extractor
class AudioAttNet(nn.Module):
    """Temporal attention over a window of per-frame audio features.

    Given seq_len consecutive audio features, predicts one attention weight
    per frame (softmax-normalized) and returns their weighted sum — a single
    smoothed feature vector.
    """

    def __init__(self, dim_aud=32, seq_len=8):
        super(AudioAttNet, self).__init__()
        self.seq_len = seq_len
        self.dim_aud = dim_aud
        # Conv stack collapses the feature channels 32 -> 16 -> 8 -> 4 -> 2 -> 1,
        # leaving one scalar score per frame in the window.
        self.attentionConvNet = nn.Sequential(  # b x subspace_dim x seq_len
            nn.Conv1d(self.dim_aud, 16, kernel_size=3,
                      stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(16, 8, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(8, 4, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(4, 2, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(2, 1, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.02, True)
        )
        self.attentionNet = nn.Sequential(
            nn.Linear(in_features=self.seq_len,
                      out_features=self.seq_len, bias=True),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        # x: [seq_len, dim] (assumed — TODO confirm against AudioNet output).
        y = x[..., :self.dim_aud].permute(1, 0).unsqueeze(
            0)  # 2 x subspace_dim x seq_len
        y = self.attentionConvNet(y)
        # One softmax weight per frame; weighted sum over the window.
        y = self.attentionNet(y.view(1, self.seq_len)).view(self.seq_len, 1)
        #print(y.view(-1).data)
        return torch.sum(y*x, dim=0)
# Model

# Audio feature extractor


# Audio feature extractor
class AudioNet(nn.Module):
    """Encode a window of 29-dim DeepSpeech audio frames into one dim_aud vector.

    Input is [n, T, 29]; a centered win_size slice is taken around index 8
    (the middle of a 16-frame window), convolved down the time axis to length
    1, then projected by two linear layers.
    """

    def __init__(self, dim_aud=76, win_size=16):
        super(AudioNet, self).__init__()
        self.win_size = win_size
        self.dim_aud = dim_aud
        self.encoder_conv = nn.Sequential(  # n x 29 x 16
            nn.Conv1d(29, 32, kernel_size=3, stride=2,
                      padding=1, bias=True),  # n x 32 x 8
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(32, 32, kernel_size=3, stride=2,
                      padding=1, bias=True),  # n x 32 x 4
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(32, 64, kernel_size=3, stride=2,
                      padding=1, bias=True),  # n x 64 x 2
            nn.LeakyReLU(0.02, True),
            nn.Conv1d(64, 64, kernel_size=3, stride=2,
                      padding=1, bias=True),  # n x 64 x 1
            nn.LeakyReLU(0.02, True),
        )
        self.encoder_fc1 = nn.Sequential(
            nn.Linear(64, 64),
            nn.LeakyReLU(0.02, True),
            nn.Linear(64, dim_aud),
        )

    def forward(self, x):
        # Take a centered win_size slice of the time axis; 8 is the midpoint
        # of the assumed 16-frame input window — TODO confirm upstream shape.
        half_w = int(self.win_size/2)
        x = x[:, 8-half_w:8+half_w, :].permute(0, 2, 1)
        x = self.encoder_conv(x).squeeze(-1)
        x = self.encoder_fc1(x).squeeze()
        return x
# Model


class FaceNeRF(nn.Module):
    """NeRF MLP conditioned on an audio feature vector.

    Identical to the standard NeRF MLP except the positional-encoded point
    is concatenated with a dim_aud audio vector before the first layer (and
    again at each skip connection).
    """

    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, dim_aud=32,
                 output_ch=4, skips=[4], use_viewdirs=False):
        """
        D/W: depth and width of the point MLP; input_ch(_views): encoded
        point/direction dims; dim_aud: audio conditioning dim; skips: layer
        indices that re-inject the input; use_viewdirs: view-dependent RGB head.
        """
        super(FaceNeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.dim_aud = dim_aud
        self.skips = skips
        self.use_viewdirs = use_viewdirs

        # Audio vector rides along with the encoded point everywhere.
        input_ch_all = input_ch + dim_aud
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch_all, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch_all, W) for i in range(D-1)])

        # Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
        # self.views_linears = nn.ModuleList(
        #     [nn.Linear(input_ch_views + W, W//2)])

        # Implementation according to the paper
        self.views_linears = nn.ModuleList(
            [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//4)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W//2, 3)
        else:
            self.output_linear = nn.Linear(W, output_ch)

    def forward(self, x):
        # x packs [encoded point + audio | encoded view dir] on the last axis.
        input_pts, input_views = torch.split(
            x, [self.input_ch+self.dim_aud, self.input_ch_views], dim=-1)
        h = input_pts
        for i, l in enumerate(self.pts_linears):
            h = self.pts_linears[i](h)
            h = F.relu(h)
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)

        if self.use_viewdirs:
            alpha = self.alpha_linear(h)
            # NOTE: feature_linear is constructed but bypassed here — the raw
            # trunk activation h is used directly (kept for checkpoint
            # compatibility; do not "fix" without retraining).
            feature = h  # self.feature_linear(h)
            h = torch.cat([feature, input_views], -1)

            for i, l in enumerate(self.views_linears):
                h = self.views_linears[i](h)
                h = F.relu(h)

            rgb = self.rgb_linear(h)
            outputs = torch.cat([rgb, alpha], -1)
        else:
            outputs = self.output_linear(h)

        return outputs

    def load_weights_from_keras(self, weights):
        """Load weights exported from the Keras NeRF implementation.

        `weights` is a flat list of alternating (kernel, bias) numpy arrays
        in the official export order; kernels are transposed to PyTorch's
        (out, in) layout.
        """
        assert self.use_viewdirs, "Not implemented if use_viewdirs=False"

        # Load pts_linears
        for i in range(self.D):
            idx_pts_linears = 2 * i
            self.pts_linears[i].weight.data = torch.from_numpy(
                np.transpose(weights[idx_pts_linears]))
            self.pts_linears[i].bias.data = torch.from_numpy(
                np.transpose(weights[idx_pts_linears+1]))

        # Load feature_linear
        idx_feature_linear = 2 * self.D
        self.feature_linear.weight.data = torch.from_numpy(
            np.transpose(weights[idx_feature_linear]))
        self.feature_linear.bias.data = torch.from_numpy(
            np.transpose(weights[idx_feature_linear+1]))

        # Load views_linears
        idx_views_linears = 2 * self.D + 2
        self.views_linears[0].weight.data = torch.from_numpy(
            np.transpose(weights[idx_views_linears]))
        self.views_linears[0].bias.data = torch.from_numpy(
            np.transpose(weights[idx_views_linears+1]))

        # Load rgb_linear
        idx_rbg_linear = 2 * self.D + 4
        self.rgb_linear.weight.data = torch.from_numpy(
            np.transpose(weights[idx_rbg_linear]))
        self.rgb_linear.bias.data = torch.from_numpy(
            np.transpose(weights[idx_rbg_linear+1]))


# Model
class NeRF(nn.Module):
    """Vanilla (audio-free) NeRF MLP; class continues beyond this chunk."""

    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=False):
        """
        """
        super(NeRF, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_views = input_ch_views
        self.skips = skips
        self.use_viewdirs = use_viewdirs

        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])

        # Implementation according to the official code release (https://github.com/bmild/nerf/blob/master/run_nerf_helpers.py#L104-L105)
        self.views_linears = nn.ModuleList(
            [nn.Linear(input_ch_views + W, W//2)])

        # Implementation according to the paper
        # self.views_linears = nn.ModuleList(
        #     [nn.Linear(input_ch_views + W, W//2)] + [nn.Linear(W//2, W//2) for i in range(D//2)])

        if use_viewdirs:
            self.feature_linear = nn.Linear(W, W)
            self.alpha_linear = nn.Linear(W, 1)
            self.rgb_linear = nn.Linear(W//2, 3)
        else:
self.output_linear = nn.Linear(W, output_ch) + + def forward(self, x): + input_pts, input_views = torch.split( + x, [self.input_ch, self.input_ch_views], dim=-1) + h = input_pts + for i, l in enumerate(self.pts_linears): + h = self.pts_linears[i](h) + h = F.relu(h) + if i in self.skips: + h = torch.cat([input_pts, h], -1) + + if self.use_viewdirs: + alpha = self.alpha_linear(h) + feature = self.feature_linear(h) + h = torch.cat([feature, input_views], -1) + + for i, l in enumerate(self.views_linears): + h = self.views_linears[i](h) + h = F.relu(h) + + rgb = self.rgb_linear(h) + outputs = torch.cat([rgb, alpha], -1) + else: + outputs = self.output_linear(h) + + return outputs + + def load_weights_from_keras(self, weights): + assert self.use_viewdirs, "Not implemented if use_viewdirs=False" + + # Load pts_linears + for i in range(self.D): + idx_pts_linears = 2 * i + self.pts_linears[i].weight.data = torch.from_numpy( + np.transpose(weights[idx_pts_linears])) + self.pts_linears[i].bias.data = torch.from_numpy( + np.transpose(weights[idx_pts_linears+1])) + + # Load feature_linear + idx_feature_linear = 2 * self.D + self.feature_linear.weight.data = torch.from_numpy( + np.transpose(weights[idx_feature_linear])) + self.feature_linear.bias.data = torch.from_numpy( + np.transpose(weights[idx_feature_linear+1])) + + # Load views_linears + idx_views_linears = 2 * self.D + 2 + self.views_linears[0].weight.data = torch.from_numpy( + np.transpose(weights[idx_views_linears])) + self.views_linears[0].bias.data = torch.from_numpy( + np.transpose(weights[idx_views_linears+1])) + + # Load rgb_linear + idx_rbg_linear = 2 * self.D + 4 + self.rgb_linear.weight.data = torch.from_numpy( + np.transpose(weights[idx_rbg_linear])) + self.rgb_linear.bias.data = torch.from_numpy( + np.transpose(weights[idx_rbg_linear+1])) + + # Load alpha_linear + idx_alpha_linear = 2 * self.D + 6 + self.alpha_linear.weight.data = torch.from_numpy( + np.transpose(weights[idx_alpha_linear])) + 
self.alpha_linear.bias.data = torch.from_numpy( + np.transpose(weights[idx_alpha_linear+1])) + + +# Ray helpers +def get_rays(H, W, focal, c2w, cx=None, cy=None, device_cur=torch.device('cuda', 0)): + # pytorch's meshgrid has indexing='ij' + i, j = torch.meshgrid(torch.linspace(0, W-1, W, device=device_cur, dtype=torch.float32), + torch.linspace(0, H-1, H, device=device_cur, dtype=torch.float32)) + i = i.t() + j = j.t() + if cx is None: + cx = W*.5 + if cy is None: + cy = H*.5 + dirs = torch.stack( + [(i-cx)/focal, -(j-cy)/focal, -torch.ones_like(i)], -1) + # Rotate ray directions from camera frame to the world frame + # dot product, equals to: [c2w.dot(dir) for dir in dirs] + rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1) + # Translate camera frame's origin to the world frame. It is the origin of all rays. + rays_o = c2w[:3, -1].expand(rays_d.shape) + return rays_o, rays_d + + +def get_rays_np(H, W, focal, c2w, cx=None, cy=None): + if cx is None: + cx = W*.5 + if cy is None: + cy = H*.5 + i, j = np.meshgrid(np.arange(W, dtype=np.float32), + np.arange(H, dtype=np.float32), indexing='xy') + dirs = np.stack([(i-cx)/focal, -(j-cy)/focal, -np.ones_like(i)], -1) + # Rotate ray directions from camera frame to the world frame + # dot product, equals to: [c2w.dot(dir) for dir in dirs] + rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1) + # Translate camera frame's origin to the world frame. It is the origin of all rays. + rays_o = np.broadcast_to(c2w[:3, -1], np.shape(rays_d)) + return rays_o, rays_d + + +def ndc_rays(H, W, focal, near, rays_o, rays_d): + # Shift ray origins to near plane + t = -(near + rays_o[..., 2]) / rays_d[..., 2] + rays_o = rays_o + t[..., None] * rays_d + + # Projection + o0 = -1./(W/(2.*focal)) * rays_o[..., 0] / rays_o[..., 2] + o1 = -1./(H/(2.*focal)) * rays_o[..., 1] / rays_o[..., 2] + o2 = 1. + 2. 
* near / rays_o[..., 2] + + d0 = -1./(W/(2.*focal)) * \ + (rays_d[..., 0]/rays_d[..., 2] - rays_o[..., 0]/rays_o[..., 2]) + d1 = -1./(H/(2.*focal)) * \ + (rays_d[..., 1]/rays_d[..., 2] - rays_o[..., 1]/rays_o[..., 2]) + d2 = -2. * near / rays_o[..., 2] + + rays_o = torch.stack([o0, o1, o2], -1) + rays_d = torch.stack([d0, d1, d2], -1) + + return rays_o, rays_d + + +# Hierarchical sampling (section 5.2) +def sample_pdf(bins, weights, N_samples, det=False, pytest=False): + # Get pdf + weights = weights + 1e-5 # prevent nans + pdf = weights / torch.sum(weights, -1, keepdim=True) + cdf = torch.cumsum(pdf, -1) + # (batch, len(bins)) + cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) + + # Take uniform samples + if det: + u = torch.linspace(0., 1., steps=N_samples, device=cdf.device) + u = u.expand(list(cdf.shape[:-1]) + [N_samples]) + else: + u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=cdf.device) + + # Pytest, overwrite u with numpy's fixed random numbers + if pytest: + np.random.seed(0) + new_shape = list(cdf.shape[:-1]) + [N_samples] + if det: + u = np.linspace(0., 1., N_samples) + u = np.broadcast_to(u, new_shape) + else: + u = np.random.rand(*new_shape) + u = torch.Tensor(u) + + # Invert CDF + u = u.contiguous() + #inds = searchsorted(cdf, u, side='right') + inds = torch.searchsorted(cdf, u, right=True) + below = torch.max(torch.zeros_like(inds-1), inds-1) + above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds) + inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2) + + # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) + # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2) + matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]] + cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g) + bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g) + + denom = (cdf_g[..., 1]-cdf_g[..., 0]) + denom = torch.where(denom < 1e-5, 
def load_audface_data(basedir, testskip=1, test_file=None, aud_file=None, test_size=-1):
    """Load AD-NeRF audio/face data from `basedir`.

    Two modes:
    * `test_file` given: load only render poses and audio features and return
      (poses, auds, bc_img, [H, W, focal, cx, cy]).
    * otherwise: read transforms_train.json / transforms_val.json and return
      (com_imgs, poses, auds, bc_img, [H, W, focal, cx, cy], sample_rects,
      i_split) where i_split holds train/val index ranges.
    """
    if test_file is not None:
        with open(os.path.join(basedir, test_file)) as fp:
            meta = json.load(fp)
        poses, auds = [], []
        aud_features = np.load(os.path.join(basedir, aud_file))
        cur_id = 0
        for frame in meta['frames'][::testskip]:
            poses.append(np.array(frame['transform_matrix']))
            auds.append(aud_features[cur_id])
            cur_id += 1
            # Stop when audio runs out or the requested test size is reached.
            if cur_id == aud_features.shape[0] or cur_id == test_size:
                break
        poses = np.array(poses).astype(np.float32)
        auds = np.array(auds).astype(np.float32)
        bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg'))
        H, W = bc_img.shape[0], bc_img.shape[1]
        focal, cx, cy = float(meta['focal_len']), float(meta['cx']), float(meta['cy'])
        return poses, auds, bc_img, [H, W, focal, cx, cy]

    splits = ['train', 'val']
    metas = {}
    for s in splits:
        with open(os.path.join(basedir, 'transforms_{}.json'.format(s)), 'r') as fp:
            metas[s] = json.load(fp)

    all_com_imgs, all_poses, all_auds, all_sample_rects = [], [], [], []
    aud_features = np.load(os.path.join(basedir, 'aud.npy'))
    counts = [0]
    for s in splits:
        meta = metas[s]
        com_imgs, poses, auds, sample_rects = [], [], [], []
        # Training frames are never skipped; testskip == 0 also disables skipping.
        skip = 1 if (s == 'train' or testskip == 0) else testskip

        for frame in meta['frames'][::skip]:
            com_imgs.append(os.path.join(basedir, 'com_imgs',
                                         str(frame['img_id']) + '.jpg'))
            poses.append(np.array(frame['transform_matrix']))
            # Clamp the audio index so trailing frames reuse the last feature.
            auds.append(aud_features[min(frame['aud_id'], aud_features.shape[0] - 1)])
            sample_rects.append(np.array(frame['face_rect'], dtype=np.int32))

        com_imgs = np.array(com_imgs)
        poses = np.array(poses).astype(np.float32)
        auds = np.array(auds).astype(np.float32)
        counts.append(counts[-1] + com_imgs.shape[0])
        all_com_imgs.append(com_imgs)
        all_poses.append(poses)
        all_auds.append(auds)
        all_sample_rects.append(sample_rects)

    i_split = [np.arange(counts[i], counts[i + 1]) for i in range(len(splits))]
    com_imgs = np.concatenate(all_com_imgs, 0)
    poses = np.concatenate(all_poses, 0)
    auds = np.concatenate(all_auds, 0)
    sample_rects = np.concatenate(all_sample_rects, 0)

    bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg'))

    H, W = bc_img.shape[:2]
    # NOTE(review): intrinsics are read from the last split's meta; assumes
    # train/val share focal/cx/cy — confirm against the data preparation step.
    focal, cx, cy = float(meta['focal_len']), float(meta['cx']), float(meta['cy'])

    return com_imgs, poses, auds, bc_img, [H, W, focal, cx, cy], \
        sample_rects, i_split


def load_test_data(basedir, aud_file, test_pose_file='transforms_train.json',
                   testskip=1, test_size=-1, aud_start=0):
    """Load poses and audio features for test-time rendering.

    Returns (poses, auds, bc_img, [H, W, focal, cx, cy], aud_ids, torso_pose);
    `aud_start` offsets into the audio feature array, `test_size` caps the
    number of frames (-1 = no cap).
    """
    with open(os.path.join(basedir, test_pose_file)) as fp:
        meta = json.load(fp)
    poses, auds, aud_ids = [], [], []
    aud_features = np.load(aud_file)
    cur_id = 0
    for frame in meta['frames'][::testskip]:
        poses.append(np.array(frame['transform_matrix']))
        # Clamp so frames beyond the audio length reuse the last feature.
        auds.append(aud_features[min(aud_start + cur_id, aud_features.shape[0] - 1)])
        aud_ids.append(aud_start + cur_id)
        cur_id += 1
        if cur_id == test_size or cur_id == aud_features.shape[0]:
            break
    poses = np.array(poses).astype(np.float32)
    auds = np.array(auds).astype(np.float32)
    bc_img = imageio.imread(os.path.join(basedir, 'bc.jpg'))
    H, W = bc_img.shape[0], bc_img.shape[1]
    focal, cx, cy = float(meta['focal_len']), float(meta['cx']), float(meta['cy'])

    # The torso pose is always taken from the first training frame.
    with open(os.path.join(basedir, 'transforms_train.json')) as fp:
        meta_torso = json.load(fp)
    torso_pose = np.array(meta_torso['frames'][0]['transform_matrix'])
    return poses, auds, bc_img, [H, W, focal, cx, cy], aud_ids, torso_pose
kernel_size=3, stride=2, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=2, padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=2, padding=1), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(512, 512, kernel_size=3, stride=2, padding=1), + Conv2d(512, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) + + def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T) + face_embedding = self.face_encoder(face_sequences) + audio_embedding = self.audio_encoder(audio_sequences) + + audio_embedding = 
class Wav2Lip(AbstractTalkingFace):
    """Wav2Lip: GAN-based model predicting lip-synced frames from audio + face frames."""

    def __init__(self, config):
        super(Wav2Lip, self).__init__()

        # Face encoder: progressively downsamples 96x96 crops to 1x1 features.
        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)),  # 96,96

            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 48,48
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 24,24
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 12,12
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 6,6
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 3,3
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),

            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0),  # 1, 1
                          Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])

        # Audio encoder: maps a (1, 80, 16) mel chunk to a 512-d embedding.
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        # Decoder: upsamples the audio embedding, concatenating encoder skips.
        self.face_decoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0),  # 3,3
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),  # 6, 6

            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),),  # 12, 12

            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),),  # 24, 24

            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),),  # 48, 48

            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),])  # 96,96

        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
                                          nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
                                          nn.Sigmoid())
        self.config = config
        self.l1loss = nn.L1Loss()
        self.bceloss = nn.BCELoss()
        # Lazily-loaded frozen lip-sync expert (see load_syncnet).
        self._syncnet = None

    def forward(self, audio_sequences, face_sequences):
        """Generate frames.

        audio_sequences: (B, T, 1, 80, 16); face_sequences: (B, 6, T, H, W)
        for sequence input, or a single 4-D batch. Returns (B, 3, T, H, W)
        or (B, 3, H, W) accordingly.
        """
        B = audio_sequences.size(0)

        input_dim_size = len(face_sequences.size())
        if input_dim_size > 4:
            # Fold the time dimension into the batch for per-frame processing.
            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1

        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)

        x = audio_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            try:
                # U-Net style skip: concat the matching encoder feature map.
                x = torch.cat((x, feats[-1]), dim=1)
            except Exception as e:
                print(x.size())
                print(feats[-1].size())
                raise e
            feats.pop()

        x = self.output_block(x)

        if input_dim_size > 4:
            x = torch.split(x, B, dim=0)  # [(B, C, H, W)]
            outputs = torch.stack(x, dim=2)  # (B, C, T, H, W)
        else:
            outputs = x

        return outputs

    def predict(self, audio_sequences, face_sequences):
        """Inference entry point; identical to forward."""
        return self.forward(audio_sequences, face_sequences)

    def calculate_loss(self, interaction, valid=False):
        r"""Calculate the training loss for a batch of data.

        Args:
            interaction (Interaction): batch with 'indiv_mels', 'input_frames',
                'mels' and 'gt' tensors.
            valid (bool): force computing the sync loss even if its weight is 0.

        Returns:
            dict: {"loss", "l1loss", "sync_loss"}.
        """
        indiv_mels = interaction['indiv_mels'].to(self.config['device'])
        input_frames = interaction['input_frames'].to(self.config['device'])
        mel = interaction['mels'].to(self.config['device'])
        gt = interaction['gt'].to(self.config['device'])
        g_frames = self.forward(indiv_mels, input_frames)
        l1loss = self.l1loss(g_frames, gt)
        if self.config['syncnet_wt'] > 0 or valid:
            sync_loss = self.syncnet_loss(mel, g_frames)
        else:
            sync_loss = 0

        loss = self.config['syncnet_wt'] * sync_loss + (1 - self.config['syncnet_wt']) * l1loss
        return {"loss": loss, "l1loss": l1loss, "sync_loss": sync_loss}

    def syncnet_loss(self, mel, g_frames):
        """Expert-discriminator sync loss on the lower half of generated frames."""
        syncnet = self.load_syncnet()
        syncnet.eval()
        g = g_frames[:, :, :, g_frames.size(3) // 2:]
        g = torch.cat([g[:, :, i] for i in range(self.config['syncnet_T'])], dim=1)
        # B, 3 * T, H//2, W
        a, v = syncnet(mel, g)
        y = torch.ones(g.size(0), 1).float().to(self.config['device'])
        return self.cosine_loss(a, v, y)

    def cosine_loss(self, a, v, y):
        """BCE on the cosine similarity between audio and face embeddings."""
        d = nn.functional.cosine_similarity(a, v)
        loss = self.bceloss(d.unsqueeze(1), y)
        return loss

    def load_syncnet(self):
        """Return the frozen SyncNet expert, loading its checkpoint only once.

        FIX: previously the checkpoint was re-read from disk on every
        syncnet_loss call (i.e. every training step); the instance is now
        cached on the model. getattr keeps this safe for instances created
        before `_syncnet` existed.
        """
        if getattr(self, '_syncnet', None) is None:
            syncnet = SyncNet_color().to(self.config['device'])
            for p in syncnet.parameters():
                p.requires_grad = False
            checkpoint = torch.load(self.config["syncnet_checkpoint_path"])
            s = checkpoint["state_dict"]
            # Strip DataParallel's 'module.' prefix from checkpoint keys.
            new_s = {k.replace('module.', ''): v for k, v in s.items()}
            syncnet.load_state_dict(new_s)
            self._syncnet = syncnet
        return self._syncnet

    def generate_batch(self):
        """Generate lip-synced videos for every entry in `test_filelist`.

        Returns:
            dict: {'generated_video': [...paths...], 'real_video': [...paths...]}.
        """
        audio_processor = Wav2LipAudio(self.config)
        video_processor = Wav2LipPreprocessForInference(self.config)

        with open(self.config['test_filelist'], 'r') as filelist:
            lines = filelist.readlines()

        file_dict = {'generated_video': [], 'real_video': []}
        for idx, line in enumerate(tqdm(lines, desc='generate video')):
            file_src = line.split()[0]

            audio_src = os.path.join(self.config['data_root'], file_src) + '.mp4'
            video = os.path.join(self.config['data_root'], file_src) + '.mp4'

            ensure_dir(os.path.join(self.config['temp_dir']))

            # Extract the audio track with ffmpeg.
            command = 'ffmpeg -loglevel panic -y -i {} -strict -2 {}'.format(audio_src, os.path.join(self.config['temp_dir'], 'temp') + '.wav')
            subprocess.call(command, shell=True)

            temp_audio = os.path.join(self.config['temp_dir'], 'temp') + '.wav'
            wav = audio_processor.load_wav(temp_audio, 16000)
            mel = audio_processor.melspectrogram(wav)

            if np.isnan(mel.reshape(-1)).sum() > 0:
                continue

            # Chop the mel spectrogram into fps-aligned chunks.
            mel_idx_multiplier = 80. / self.config['fps']
            mel_chunks = []
            i = 0
            while 1:
                start_idx = int(i * mel_idx_multiplier)
                if start_idx + self.config['mel_step_size'] > len(mel[0]):
                    break
                mel_chunks.append(mel[:, start_idx: start_idx + self.config['mel_step_size']])
                i += 1

            video_stream = cv2.VideoCapture(video)
            full_frames = []
            while 1:
                still_reading, frame = video_stream.read()
                if not still_reading or len(full_frames) > len(mel_chunks):
                    video_stream.release()
                    break
                full_frames.append(frame)

            if len(full_frames) < len(mel_chunks):
                continue

            full_frames = full_frames[:len(mel_chunks)]

            try:
                face_det_results = video_processor.face_detect(full_frames.copy())
            except ValueError:
                # No face found in some frame; skip this clip.
                continue

            gen = video_processor.datagen(full_frames.copy(), face_det_results, mel_chunks)

            # FIX: `out` was referenced after the loop and unbound when the
            # generator yielded nothing, raising NameError; guard with None.
            out = None
            for i, (img_batch, mel_batch, frames, coords) in enumerate(gen):
                if i == 0:
                    frame_h, frame_w = full_frames[0].shape[:-1]
                    output_video_path = os.path.join(self.config['temp_dir'], 'temp') + '.mp4'
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # or try 'avc1'
                    out = cv2.VideoWriter(output_video_path, fourcc, 25, (frame_w, frame_h))

                img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(self.config['device'])
                mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(self.config['device'])

                with torch.no_grad():
                    pred = self.predict(mel_batch, img_batch)

                pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.

                # Paste each predicted mouth region back into its source frame.
                for pl, f, c in zip(pred, frames, coords):
                    y1, y2, x1, x2 = c
                    pl = cv2.resize(pl.astype(np.uint8), (x2 - x1, y2 - y1))
                    f[y1:y2, x1:x2] = pl
                    out.write(f)

            if out is None:
                continue
            out.release()

            vid = os.path.join(self.config['temp_dir'], file_src) + '.mp4'
            vid_directory = os.path.dirname(vid)
            if not os.path.exists(vid_directory):
                os.makedirs(vid_directory)

            # Mux the generated frames with the extracted audio.
            command = 'ffmpeg -loglevel panic -y -i {} -i {} -strict -2 -q:v 1 {}'.format(temp_audio,
                                                                                          os.path.join(self.config['temp_dir'], 'temp') + '.mp4', vid)
            process_status = subprocess.call(command, shell=True)
            if process_status == 0:
                file_dict['generated_video'].append(vid)
                file_dict['real_video'].append(video)
            else:
                continue
        return file_dict
Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './checkpoints/live_speech_portraits/McStay/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './checkpoints/live_speech_portraits/McStay/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './checkpoints/live_speech_portraits/McStay/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + +#我自己加的 +checkpoint_sub_dir: "/live_speech_portraits" # 和overall.yaml里checkpoint_dir拼起来作为最终目录 +temp_sub_dir: "/live_speech_portraits" # 和overall.yaml里temp_dir拼起来作为最终目录 +driving_audio_path: './checkpoints/live_speech_portraits/Input/00083.wav' #驱动音频路径 +save_intermediates: 0 #是否存储中间文件 + +dataset_params: + root: './checkpoints/live_speech_portraits/McStay' + fit_data_path: './checkpoints/live_speech_portraits/McStay/3d_fit_data.npz' + pts3d_path: './checkpoints/live_speech_portraits/McStay/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/May.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/May.yaml new file mode 100644 index 00000000..d28c4e76 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/May.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/May/checkpoints/Audio2Feature.pkl' + smooth: 1.5 + AMP: ['XYZ', 2, 2, 2] # method, x, y, z + Headpose: + ckp_path: './data/May/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [1, 0.5] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/May/checkpoints/Feature2Face.pkl' + size: 'large' + save_input: 1 + + +dataset_params: + root: './data/May/' + 
fit_data_path: './data/May/3d_fit_data.npz' + pts3d_path: './data/May/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/McStay.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/McStay.yaml new file mode 100644 index 00000000..25d9db17 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/McStay.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/McStay/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/McStay/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/McStay/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/McStay/' + fit_data_path: './data/McStay/3d_fit_data.npz' + pts3d_path: './data/McStay/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Nadella.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Nadella.yaml new file mode 100644 index 00000000..66f33573 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Nadella.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/Nadella/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/Nadella/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [0.5, 0.5] # rot, trans + 
shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/Nadella/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/Nadella/' + fit_data_path: './data/Nadella/3d_fit_data.npz' + pts3d_path: './data/Nadella/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama1.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama1.yaml new file mode 100644 index 00000000..ce414876 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama1.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/Obama1/checkpoints/Audio2Feature.pkl' + smooth: 1 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/Obama1/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [2, 8] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/Obama1/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/Obama1/' + fit_data_path: './data/Obama1/3d_fit_data.npz' + pts3d_path: './data/Obama1/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama2.yaml b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama2.yaml new file mode 100644 index 00000000..6d543151 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/properties/model/AD_NeRF_master/Obama2.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/Obama2/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 
1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/Obama2/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [3, 10] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/Obama2/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/Obama2/' + fit_data_path: './data/Obama2/3d_fit_data.npz' + pts3d_path: './data/Obama2/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface-toolkit-main/talkingface/properties/model/Wav2Lip.yaml b/talkingface-toolkit-main/talkingface/properties/model/Wav2Lip.yaml new file mode 100644 index 00000000..50b044cc --- /dev/null +++ b/talkingface-toolkit-main/talkingface/properties/model/Wav2Lip.yaml @@ -0,0 +1,55 @@ +# Syncnet +syncnet_wt: 0.03 # (int) is initially zero, will be set automatically to 0.03 later.Leads to faster convergence. +syncnet_batch_size: 64 # (int) batch_size for syncnet train +syncnet_lr: 0.0001 #(float) learning rate for syncnet train +syncnet_eval_interval: 10000 +syncnet_checkpoint_interval: 10000 +syncnet_T: 5 +syncnet_mel_step_size: 16 +syncnet_checkpoint_path: "checkpoints/wav2lip/lipsync_expert.pth" + +# Data preprocessing for Wav2lip +num_mels: 80 +rescale: True +rescaling_max: 0.9 +use_lws: False +n_fft: 800 +hop_size: 200 +win_size: 800 +sample_rate: 16000 +frame_shift_ms: None +signal_normalization: True +allow_clipping_in_normalization: True +symmetric_mels: True +max_abs_value: 4 +preemphasize: True +preemphasis: 0.97 +min_level_db: -100 +ref_level_db: 20 +fmin: 55 +fmax: 7600 +img_size: 96 +fps: 25 +mel_step_size: 16 + +batch_size: 16 +ngpu: 1 + + +# Train +checkpoint_sub_dir: "/wav2lip" # 和overall.yaml里checkpoint_dir拼起来作为最终目录 + +temp_sub_dir: "/wav2lip" # 和overall.yaml里temp_dir拼起来作为最终目录 + + +# Inference +pads: [0, 10, 0, 0] +static: False +face_det_batch_size: 16 +resize_factor: 1 +crop: [0, -1, 0, -1] +box: [-1, -1, -1, -1] +rotate: False +nosmooth: False +wav2lip_batch_size: 128 +vshift: 
15 \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/properties/overall.yaml b/talkingface-toolkit-main/talkingface/properties/overall.yaml new file mode 100644 index 00000000..81ac51ae --- /dev/null +++ b/talkingface-toolkit-main/talkingface/properties/overall.yaml @@ -0,0 +1,31 @@ +# Environment Settings +gpu_id: '3, 4, 5' # (str) The id of GPU device(s). +worker: 0 # (int) The number of workers processing the data. +use_gpu: True # (bool) Whether or not to use GPU. +seed: 2023 # (int) Random seed. +checkpoint_dir: 'saved' # (str) The path to save checkpoint file. +show_progress: True # (bool) Whether or not to show the progress bar of every epoch. +log_wandb: False # (bool) Whether or not to use Weights & Biases(W&B). +shuffle: True # (bool) Whether or not to shuffle the training data before each epoch. +device: 'cuda' +reproducibility: True # (bool) Whether or not to make results reproducible. + +# Training Settings +epochs: 300 # (int) The number of training epochs. +train_batch_size: 2048 # (int) The training batch size. +learner: adam # (str) The name of used optimizer. +learning_rate: 0.0001 # (float) Learning rate. +eval_step: 1 # (int) The number of training epochs before an evaluation on the valid dataset. +stopping_step: 10 # (int) The threshold for validation-based early stopping. +weight_decay: 0.0 # (float) The weight decay value (L2 penalty) for optimizers. +saved: True +resume: True +train: True + +# Evaluation Settings +metrics: ["LSE", "SSIM"] +evaluate_batch_size: 50 # (int) The evaluation batch size. +lse_checkpoint_path: 'checkpoints/LSE/syncnet_v2.model' +temp_dir: 'results/temp' +lse_reference_dir: 'lse' +valid_metric_bigger: False # (bool) Whether to take a bigger valid metric value as a better result.
# ---- talkingface/quick_start/__init__.py -----------------------------------
from talkingface.quick_start.quick_start import (
    run,
    run_talkingface,
)
from talkingface.quick_start.meta_portrait_base_main import *

# ---- talkingface/quick_start/quick_start.py --------------------------------
import logging
import sys
import torch.distributed as dist
from collections.abc import MutableMapping
from logging import getLogger
import os
from torch.utils import data as data_utils
from ray import tune  # NOTE(review): imported but unused in this module

from talkingface.config import Config

from talkingface.utils import (
    init_logger,
    get_model,
    get_trainer,
    init_seed,
    set_color,
    get_flops,
    get_environment,
    get_preprocess,
    create_dataset,
)


def run(model, dataset, config_file_list=None, config_dict=None, saved=True,
        evaluate_model_file=None):
    """Convenience wrapper that forwards every argument to run_talkingface.

    Args:
        model (str): Model name.
        dataset (str): Dataset name.
        config_file_list (list, optional): Config files overriding parameters.
        config_dict (dict, optional): Parameter overrides.
        saved (bool, optional): Whether to save the model. Defaults to True.
        evaluate_model_file (str, optional): Checkpoint to evaluate.
    """
    return run_talkingface(
        model=model,
        dataset=dataset,
        config_file_list=config_file_list,
        config_dict=config_dict,
        saved=saved,
        evaluate_model_file=evaluate_model_file,
    )


def run_talkingface(model=None, dataset=None, config_file_list=None,
                    config_dict=None, saved=True, queue=None,
                    evaluate_model_file=None):
    """A fast running api: the complete process of training and testing a
    model on a specified dataset.

    Args:
        model (str, optional): Model name. Defaults to ``None``.
        dataset (str, optional): Dataset name. Defaults to ``None``.
        config_file_list (list, optional): Config files used to modify experiment parameters. Defaults to ``None``.
        config_dict (dict, optional): Parameters dictionary used to modify experiment parameters. Defaults to ``None``.
        saved (bool, optional): Whether to save the model. Defaults to ``True``.
        queue (torch.multiprocessing.Queue, optional): The queue used to pass the result to the main process. Defaults to ``None``.
        evaluate_model_file (str, optional): Checkpoint file evaluated when
            training is disabled. Defaults to ``None``.
    """
    config = Config(
        model=model,
        dataset=dataset,
        config_file_list=config_file_list,
        config_dict=config_dict,
    )
    init_seed(config["seed"], config["reproducibility"])
    init_logger(config)
    logger = getLogger()
    logger.info(sys.argv)
    logger.info(config)

    # Run the dataset-specific preprocessor once, iff preprocessing is
    # requested and its output directory is missing or empty.
    if config['need_preprocess'] and (
            not os.path.exists(config['preprocessed_root'])
            or not os.listdir(config['preprocessed_root'])):
        get_preprocess(config['dataset'])(config).run()

    train_dataset, val_dataset = create_dataset(config)
    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=config["batch_size"], shuffle=True
    )
    val_data_loader = data_utils.DataLoader(
        val_dataset, batch_size=config["batch_size"], shuffle=False
    )

    # Load the model onto the configured device.
    model = get_model(config["model"])(config).to(config["device"])
    logger.info(model)

    trainer = get_trainer(config["model"])(config, model)

    # Model training.
    if config['train']:
        trainer.fit(train_data_loader, val_data_loader, saved=saved,
                    show_progress=config["show_progress"])

    if not config['train'] and evaluate_model_file is None:
        # FIX: report through the configured logger instead of a bare print().
        logger.error("no model file to evaluate without training")
        return
    # Model evaluating.
    trainer.evaluate(model_file=evaluate_model_file)
# ==== talkingface/trainer/AD_NeRF_masterTrainer.py ==========================
import os
import sys
import json
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE(review): project / heavyweight third-party imports are guarded so that
# the pure-tensor helpers below can be imported (and unit-tested) without the
# full AD-NeRF environment installed. In the full environment every import
# succeeds and behavior is unchanged.
try:
    from load_audface import load_audface_data
    import imageio
    from tqdm import tqdm, trange
    from natsort import natsorted
    from run_nerf_helpers import *
except ImportError:  # pragma: no cover
    pass


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
np.random.seed(0)  # NOTE(review): global seeding at import time affects callers too
DEBUG = False


def batchify(fn, chunk):
    """Return a version of `fn` that is applied to `chunk`-sized input slices.

    Args:
        fn: callable taking a single tensor batch (dim 0 is the batch dim).
        chunk: slice size along dim 0, or ``None`` to disable chunking.
    """
    if chunk is None:
        return fn

    def chunked(inputs):
        pieces = [fn(inputs[start:start + chunk])
                  for start in range(0, inputs.shape[0], chunk)]
        return torch.cat(pieces, 0)

    return chunked


def run_network(inputs, viewdirs, aud_para, fn, embed_fn, embeddirs_fn,
                netchunk=1024 * 64):
    """Embed sampled points (plus the audio code and optional view dirs) and
    query the NeRF network `fn`, preserving the input's leading shape.

    Args:
        inputs: [..., 3] sampled point positions.
        viewdirs: per-ray viewing directions, or ``None``.
        aud_para: per-frame audio feature vector, broadcast to every point.
        fn: the NeRF MLP.
        embed_fn / embeddirs_fn: positional-encoding functions.
        netchunk: points per forward pass (memory knob only).
    """
    flat_pts = torch.reshape(inputs, [-1, inputs.shape[-1]])
    embedded = embed_fn(flat_pts)
    # Broadcast the per-frame audio feature to every sampled point.
    aud = aud_para.unsqueeze(0).expand(flat_pts.shape[0], -1)
    embedded = torch.cat((embedded, aud), -1)
    if viewdirs is not None:
        dirs = viewdirs[:, None].expand(inputs.shape)
        flat_dirs = torch.reshape(dirs, [-1, dirs.shape[-1]])
        embedded = torch.cat([embedded, embeddirs_fn(flat_dirs)], -1)

    flat_out = batchify(fn, netchunk)(embedded)
    return torch.reshape(flat_out,
                         list(inputs.shape[:-1]) + [flat_out.shape[-1]])


def batchify_rays(rays_flat, bc_rgb, aud_para, chunk=1024 * 32, **kwargs):
    """Render rays in `chunk`-sized minibatches to avoid OOM, concatenating
    the per-chunk outputs key by key."""
    all_ret = {}
    for start in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[start:start + chunk],
                          bc_rgb[start:start + chunk],
                          aud_para, **kwargs)
        for key, val in ret.items():
            all_ret.setdefault(key, []).append(val)

    return {key: torch.cat(vals, 0) for key, vals in all_ret.items()}


def render_dynamic_face(H, W, focal, cx, cy, chunk=1024 * 32, rays=None,
                        bc_rgb=None, aud_para=None, c2w=None, ndc=True,
                        near=0., far=1., use_viewdirs=False,
                        c2w_staticcam=None, **kwargs):
    """Render an audio-conditioned face frame.

    When `c2w` is given the whole image is rendered from that camera pose;
    otherwise `rays` supplies an explicit (origins, directions) batch.

    Returns:
        [rgb_map, disp_map, acc_map, last_weight, extras_dict].
    """
    if c2w is not None:
        # Special case: render the full image from a camera pose.
        rays_o, rays_d = get_rays(H, W, focal, c2w, cx, cy)
        bc_rgb = bc_rgb.reshape(-1, 3)
    else:
        # Use the provided ray batch.
        rays_o, rays_d = rays

    if use_viewdirs:
        # Provide ray directions as an extra network input.
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # Special case: visualize the effect of viewdirs with a fixed camera.
            rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam, cx, cy)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1, 3]).float()

    sh = rays_d.shape  # [..., 3]
    if ndc:
        # Only appropriate for forward-facing scenes.
        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)

    # Create the flat ray batch: origin, direction, near, far (+ viewdirs).
    rays_o = torch.reshape(rays_o, [-1, 3]).float()
    rays_d = torch.reshape(rays_d, [-1, 3]).float()

    near = near * torch.ones_like(rays_d[..., :1])
    far = far * torch.ones_like(rays_d[..., :1])
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)

    # Render, then restore the original leading shape.
    all_ret = batchify_rays(rays, bc_rgb, aud_para, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map', 'last_weight']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]
def render(H, W, focal, cx, cy,
           chunk=1024 * 32, rays=None, c2w=None, ndc=True,
           near=0., far=1.,
           use_viewdirs=False, c2w_staticcam=None,
           bc_rgb=None, aud_para=None,
           **kwargs):
    """Render rays.

    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      focal: float. Focal length of pinhole camera.
      cx, cy: principal point of the camera.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
        each example in batch.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
        camera while using other c2w argument for viewing directions.
      bc_rgb: background RGB used by render_rays; must be provided (new
        optional parameter, see FIX below).
      aud_para: per-frame audio feature passed through to render_rays (new
        optional parameter, see FIX below).
    Returns:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything returned by render_rays().
    """
    if c2w is not None:
        # Special case: render the full image.
        rays_o, rays_d = get_rays(H, W, focal, c2w, cx, cy)
        if bc_rgb is not None:
            # Flatten the background image to per-ray pixels, mirroring
            # render_dynamic_face.
            bc_rgb = bc_rgb.reshape(-1, 3)
    else:
        # Use the provided ray batch.
        rays_o, rays_d = rays

    if use_viewdirs:
        # Provide ray directions as an extra network input.
        viewdirs = rays_d
        if c2w_staticcam is not None:
            # Special case: visualize the effect of viewdirs.
            rays_o, rays_d = get_rays(H, W, focal, c2w_staticcam, cx, cy)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)
        viewdirs = torch.reshape(viewdirs, [-1, 3]).float()

    sh = rays_d.shape  # [..., 3]
    if ndc:
        # Only appropriate for forward-facing scenes.
        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)

    # Create ray batch.
    rays_o = torch.reshape(rays_o, [-1, 3]).float()
    rays_d = torch.reshape(rays_d, [-1, 3]).float()

    near = near * torch.ones_like(rays_d[..., :1])
    far = far * torch.ones_like(rays_d[..., :1])
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)

    # FIX: batchify_rays has signature (rays_flat, bc_rgb, aud_para, chunk,
    # **kwargs); the original call `batchify_rays(rays, chunk, **kwargs)`
    # passed `chunk` into the bc_rgb slot and omitted aud_para entirely, so
    # every call raised a TypeError.
    all_ret = batchify_rays(rays, bc_rgb, aud_para, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)

    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k: all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]
def render_path(render_poses, aud_paras, bc_img, hwfcxy,
                chunk, render_kwargs, gt_imgs=None, savedir=None,
                render_factor=0):
    """Render a sequence of camera poses and return stacked rgb, disparity
    and last-weight maps (optionally writing each frame to `savedir`)."""
    H, W, focal, cx, cy = hwfcxy

    if render_factor != 0:
        # Render at reduced resolution for speed.
        H = H // render_factor
        W = W // render_factor
        focal = focal / render_factor

    rgbs, disps, last_weights = [], [], []

    t = time.time()
    for idx, c2w in enumerate(tqdm(render_poses)):
        print(idx, time.time() - t)
        t = time.time()
        rgb, disp, acc, last_weight, _ = render_dynamic_face(
            H, W, focal, cx, cy, chunk=chunk, c2w=c2w[:3, :4],
            aud_para=aud_paras[idx], bc_rgb=bc_img,
            **render_kwargs)
        rgbs.append(rgb.cpu().numpy())
        disps.append(disp.cpu().numpy())
        last_weights.append(last_weight.cpu().numpy())
        if idx == 0:
            print(rgb.shape, disp.shape)

        if savedir is not None:
            rgb8 = to8b(rgbs[-1])
            filename = os.path.join(savedir, '{:03d}.png'.format(idx))
            imageio.imwrite(filename, rgb8)

    return np.stack(rgbs, 0), np.stack(disps, 0), np.stack(last_weights, 0)


def create_nerf(args):
    """Instantiate the coarse/fine FaceNeRF MLPs, the optimizer, and the
    train/test render kwargs; reload checkpoint state when available."""
    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(
            args.multires_views, args.i_embed)
    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]
    model = FaceNeRF(D=args.netdepth, W=args.netwidth,
                     input_ch=input_ch, dim_aud=args.dim_aud,
                     output_ch=output_ch, skips=skips,
                     input_ch_views=input_ch_views,
                     use_viewdirs=args.use_viewdirs).to(device)
    grad_vars = list(model.parameters())

    model_fine = None
    if args.N_importance > 0:
        model_fine = FaceNeRF(D=args.netdepth_fine, W=args.netwidth_fine,
                              input_ch=input_ch, dim_aud=args.dim_aud,
                              output_ch=output_ch, skips=skips,
                              input_ch_views=input_ch_views,
                              use_viewdirs=args.use_viewdirs).to(device)
        grad_vars += list(model_fine.parameters())

    def network_query_fn(inputs, viewdirs, aud_para, network_fn):
        return run_network(inputs, viewdirs, aud_para, network_fn,
                           embed_fn=embed_fn, embeddirs_fn=embeddirs_fn,
                           netchunk=args.netchunk)

    # Create the optimizer over coarse (+ fine) parameters.
    optimizer = torch.optim.Adam(
        params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

    start = 0
    basedir = args.basedir
    expname = args.expname

    # ---- checkpoint reload --------------------------------------------------
    if args.ft_path is not None and args.ft_path != 'None':
        ckpts = [args.ft_path]
    else:
        ckpts = [os.path.join(basedir, expname, f)
                 for f in natsorted(os.listdir(os.path.join(basedir, expname)))
                 if 'tar' in f]

    print('Found ckpts', ckpts)
    learned_codes_dict = None
    AudNet_state = None
    AudAttNet_state = None
    optimizer_aud_state = None
    optimizer_audatt_state = None
    if len(ckpts) > 0 and not args.no_reload:
        ckpt_path = ckpts[-1]
        print('Reloading from', ckpt_path)
        ckpt = torch.load(ckpt_path)

        start = ckpt['global_step']
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        AudNet_state = ckpt['network_audnet_state_dict']
        optimizer_aud_state = ckpt['optimizer_aud_state_dict']

        # Load the NeRF model(s).
        model.load_state_dict(ckpt['network_fn_state_dict'])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt['network_fine_state_dict'])
        # Attention-net state only exists in newer checkpoints.
        if 'network_audattnet_state_dict' in ckpt:
            AudAttNet_state = ckpt['network_audattnet_state_dict']
        if 'optimize_audatt_state_dict' in ckpt:
            optimizer_audatt_state = ckpt['optimize_audatt_state_dict']

    render_kwargs_train = {
        'network_query_fn': network_query_fn,
        'perturb': args.perturb,
        'N_importance': args.N_importance,
        'network_fine': model_fine,
        'N_samples': args.N_samples,
        'network_fn': model,
        'use_viewdirs': args.use_viewdirs,
        'white_bkgd': args.white_bkgd,
        'raw_noise_std': args.raw_noise_std,
    }

    # NDC is only good for LLFF-style forward-facing data.
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp

    # Test-time kwargs: same settings, but deterministic sampling and no noise.
    render_kwargs_test = dict(render_kwargs_train)
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.

    return render_kwargs_train, render_kwargs_test, start, grad_vars, \
        optimizer, learned_codes_dict, AudNet_state, optimizer_aud_state, \
        AudAttNet_state, optimizer_audatt_state
def raw2outputs(raw, z_vals, rays_d, bc_rgb, raw_noise_std=0,
                white_bkgd=False, pytest=False):
    """Transform raw network predictions into rendered quantities.

    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
        bc_rgb: per-ray background colour substituted at the last sample.
    Returns:
        rgb_map, disp_map, acc_map, weights, depth_map (see NeRF paper).
    """
    def raw2alpha(sigma, dists, act_fn=F.relu):
        return 1. - torch.exp(-(act_fn(sigma) + 1e-6) * dists)

    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat(
        [dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)],
        -1)  # [N_rays, N_samples]
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    rgb = torch.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]
    # The last sample along each ray is replaced by the background pixel.
    rgb = torch.cat((rgb[:, :-1, :], bc_rgb.unsqueeze(1)), dim=1)
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[..., 3].shape) * raw_noise_std
        # Overwrite the random samples with numpy's fixed randomness in tests.
        if pytest:
            np.random.seed(0)
            noise = torch.Tensor(
                np.random.rand(*list(raw[..., 3].shape)) * raw_noise_std)

    alpha = raw2alpha(raw[..., 3] + noise, dists)  # [N_rays, N_samples]
    # Exclusive cumulative product of transmittance.
    weights = alpha * torch.cumprod(
        torch.cat([torch.ones((alpha.shape[0], 1)), 1. - alpha + 1e-10], -1),
        -1)[:, :-1]
    rgb_map = torch.sum(weights[..., None] * rgb, -2)  # [N_rays, 3]

    depth_map = torch.sum(weights * z_vals, -1)
    disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
                              depth_map / torch.sum(weights, -1))
    acc_map = torch.sum(weights, -1)

    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])

    return rgb_map, disp_map, acc_map, weights, depth_map


def render_rays(ray_batch,
                bc_rgb,
                aud_para,
                network_fn,
                network_query_fn,
                N_samples,
                retraw=False,
                lindisp=False,
                perturb=0.,
                N_importance=0,
                network_fine=None,
                white_bkgd=False,
                raw_noise_std=0.,
                verbose=False,
                pytest=False):
    """Volumetric rendering for one flat ray batch (coarse pass plus an
    optional importance-sampled fine pass).

    Returns a dict with rgb_map / disp_map / acc_map, the coarse-pass
    outputs (rgb0/disp0/acc0) when a fine pass ran, and last_weight.
    """
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6]  # [N_rays, 3] each
    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None
    bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])
    near, far = bounds[..., 0], bounds[..., 1]  # [-1, 1]

    t_vals = torch.linspace(0., 1., steps=N_samples)
    if lindisp:
        # Sample linearly in inverse depth.
        z_vals = 1. / (1. / near * (1. - t_vals) + 1. / far * t_vals)
    else:
        z_vals = near * (1. - t_vals) + far * t_vals
    z_vals = z_vals.expand([N_rays, N_samples])

    if perturb > 0.:
        # Stratified sampling within each interval.
        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        upper = torch.cat([mids, z_vals[..., -1:]], -1)
        lower = torch.cat([z_vals[..., :1], mids], -1)
        t_rand = torch.rand(z_vals.shape)

        # Overwrite with numpy's fixed random numbers under pytest.
        if pytest:
            np.random.seed(0)
            t_rand = torch.Tensor(np.random.rand(*list(z_vals.shape)))
            # NOTE(review): the flattened source loses indentation; this pin
            # of the last sample to the far bound may belong outside the
            # pytest guard — confirm against the upstream AD-NeRF run_nerf.py.
            t_rand[..., -1] = 1.0
        z_vals = lower + (upper - lower) * t_rand

    pts = rays_o[..., None, :] + rays_d[..., None, :] * \
        z_vals[..., :, None]  # [N_rays, N_samples, 3]
    raw = network_query_fn(pts, viewdirs, aud_para, network_fn)
    rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
        raw, z_vals, rays_d, bc_rgb, raw_noise_std, white_bkgd, pytest=pytest)

    if N_importance > 0:
        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map

        # Importance-sample extra points from the coarse weights.
        z_vals_mid = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
        z_samples = sample_pdf(z_vals_mid, weights[..., 1:-1], N_importance,
                               det=(perturb == 0.), pytest=pytest)
        z_samples = z_samples.detach()

        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[..., None, :] + rays_d[..., None, :] * \
            z_vals[..., :, None]  # [N_rays, N_samples + N_importance, 3]

        run_fn = network_fn if network_fine is None else network_fine
        raw = network_query_fn(pts, viewdirs, aud_para, run_fn)

        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(
            raw, z_vals, rays_d, bc_rgb, raw_noise_std, white_bkgd,
            pytest=pytest)

    ret = {'rgb_map': rgb_map, 'disp_map': disp_map, 'acc_map': acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]
    ret['last_weight'] = weights[..., -1]

    for k in ret:
        if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
            print(f"! [Numerical Error] {k} contains nan or inf.")

    return ret
parser.add_argument("--N_samples", type=int, default=64, + help='number of coarse samples per ray') + parser.add_argument("--N_importance", type=int, default=128, + help='number of additional fine samples per ray') + parser.add_argument("--perturb", type=float, default=1., + help='set to 0. for no jitter, 1. for jitter') + parser.add_argument("--use_viewdirs", action='store_false', + help='use full 5D input instead of 3D') + parser.add_argument("--i_embed", type=int, default=0, + help='set 0 for default positional encoding, -1 for none') + parser.add_argument("--multires", type=int, default=10, + help='log2 of max freq for positional encoding (3D location)') + parser.add_argument("--multires_views", type=int, default=4, + help='log2 of max freq for positional encoding (2D direction)') + parser.add_argument("--raw_noise_std", type=float, default=0., + help='std dev of noise added to regularize sigma_a output, 1e0 recommended') + + parser.add_argument("--render_only", action='store_true', + help='do not optimize, reload weights and render out render_poses path') + parser.add_argument("--render_test", action='store_true', + help='render the test set instead of render_poses path') + parser.add_argument("--render_factor", type=int, default=0, + help='downsampling factor to speed up rendering, set 4 or 8 for fast preview') + + # training options + parser.add_argument("--precrop_iters", type=int, default=0, + help='number of steps to train on central crops') + parser.add_argument("--precrop_frac", type=float, + default=.5, help='fraction of img taken for central crops') + + # dataset options + parser.add_argument("--dataset_type", type=str, default='audface', + help='options: llff / blender / deepvoxels') + parser.add_argument("--testskip", type=int, default=8, + help='will load 1/N images from test/val sets, useful for large datasets like deepvoxels') + + # deepvoxels flags + parser.add_argument("--shape", type=str, default='greek', + help='options : armchair / cube / 
greek / vase') + + # blender flags + parser.add_argument("--white_bkgd", action='store_false', + help='set to render synthetic data on a white bkgd (always use for dvoxels)') + parser.add_argument("--half_res", action='store_true', + help='load blender synthetic data at 400x400 instead of 800x800') + + # face flags + parser.add_argument("--with_test", type=int, default=0, + help='whether to use test set') + parser.add_argument("--dim_aud", type=int, default=64, + help='dimension of audio features for NeRF') + parser.add_argument("--sample_rate", type=float, default=0.95, + help="sample rate in a bounding box") + parser.add_argument("--near", type=float, default=0.3, + help="near sampling plane") + parser.add_argument("--far", type=float, default=0.9, + help="far sampling plane") + parser.add_argument("--test_file", type=str, default='transforms_test.json', + help='test file') + parser.add_argument("--aud_file", type=str, default='aud.npy', + help='test audio deepspeech file') + parser.add_argument("--win_size", type=int, default=16, + help="windows size of audio feature") + parser.add_argument("--smo_size", type=int, default=8, + help="window size for smoothing audio features") + parser.add_argument('--nosmo_iters', type=int, default=200000, + help='number of iterations befor applying smoothing on audio features') + + # llff flags + parser.add_argument("--factor", type=int, default=8, + help='downsample factor for LLFF images') + parser.add_argument("--no_ndc", action='store_true', + help='do not use normalized device coordinates (set for non-forward facing scenes)') + parser.add_argument("--lindisp", action='store_true', + help='sampling linearly in disparity rather than depth') + parser.add_argument("--spherify", action='store_true', + help='set for spherical 360 scenes') + parser.add_argument("--llffhold", type=int, default=8, + help='will take every 1/N images as LLFF test set, paper uses 8') + + # logging/saving options + parser.add_argument("--i_print", 
def train():
    """Train AD-NeRF's head model on an 'audface' dataset.

    End-to-end driver: parses CLI args, loads poses/audio-features/background,
    builds the NeRF + audio networks, then runs the ray-sampling optimization
    loop with periodic checkpointing, test-set rendering and console logging.
    Relies on module-level helpers (`config_parser`, `load_audface_data`,
    `create_nerf`, `get_rays*`, `render_*`, `img2mse`, `mse2psnr`, `to8b`) and
    the module-level `device`.
    """

    parser = config_parser()
    args = parser.parse_args()

    # Load data. Only the 'audface' dataset type is supported.
    if args.dataset_type == 'audface':
        if args.with_test == 1:
            # Test-only mode: no ground-truth images are loaded.
            poses, auds, bc_img, hwfcxy = \
                load_audface_data(args.datadir, args.testskip,
                                  args.test_file, args.aud_file)
            # Placeholder so images.shape below does not crash.
            images = np.zeros(1)
        else:
            images, poses, auds, bc_img, hwfcxy, sample_rects, mouth_rects, i_split = load_audface_data(
                args.datadir, args.testskip)
        print('Loaded audface', images.shape, hwfcxy, args.datadir)
        if args.with_test == 0:
            i_train, i_val = i_split

        near = args.near
        far = args.far
    else:
        print('Unknown dataset type', args.dataset_type, 'exiting')
        return

    # Cast intrinsics to right types.
    # hwfcxy = height, width, focal length, principal point (cx, cy).
    H, W, focal, cx, cy = hwfcxy
    H, W = int(H), int(W)
    hwf = [H, W, focal]  # NOTE(review): unused below
    hwfcxy = [H, W, focal, cx, cy]

    # if args.render_test:
    #     render_poses = np.array(poses[i_test])

    # Create log dir and snapshot the resolved arguments / config file there.
    basedir = args.basedir
    expname = args.expname
    os.makedirs(os.path.join(basedir, expname), exist_ok=True)
    f = os.path.join(basedir, expname, 'args.txt')
    with open(f, 'w') as file:
        for arg in sorted(vars(args)):
            attr = getattr(args, arg)
            file.write('{} = {}\n'.format(arg, attr))
    if args.config is not None:
        f = os.path.join(basedir, expname, 'config.txt')
        with open(f, 'w') as file:
            # NOTE(review): inner open() handle is never closed explicitly.
            file.write(open(args.config, 'r').read())

    # Create NeRF model (coarse+fine) and restore any saved states.
    # NOTE(review): grad_vars and learned_codes are returned but unused here.
    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer, \
        learned_codes, AudNet_state, optimizer_aud_state, AudAttNet_state, optimizer_audatt_state \
        = create_nerf(args)
    global_step = start

    # Audio feature network and its temporal-attention smoother, with
    # separate Adam optimizers at the base learning rate.
    AudNet = AudioNet(args.dim_aud, args.win_size).to(device)
    AudAttNet = AudioAttNet().to(device)
    optimizer_Aud = torch.optim.Adam(
        params=list(AudNet.parameters()), lr=args.lrate, betas=(0.9, 0.999))
    optimizer_AudAtt = torch.optim.Adam(
        params=list(AudAttNet.parameters()), lr=args.lrate, betas=(0.9, 0.999))

    # Resume audio-network / optimizer states when a checkpoint provided them.
    if AudNet_state is not None:
        AudNet.load_state_dict(AudNet_state, strict=False)
    if optimizer_aud_state is not None:
        optimizer_Aud.load_state_dict(optimizer_aud_state)
    if AudAttNet_state is not None:
        AudAttNet.load_state_dict(AudAttNet_state, strict=False)
    if optimizer_audatt_state is not None:
        optimizer_AudAtt.load_state_dict(optimizer_audatt_state)
    # Near/far sampling planes are shared by train and test renderers.
    bds_dict = {
        'near': near,
        'far': far,
    }
    render_kwargs_train.update(bds_dict)
    render_kwargs_test.update(bds_dict)

    # Move training data to GPU; background image is normalized to [0, 1].
    bc_img = torch.Tensor(bc_img).to(device).float()/255.0
    poses = torch.Tensor(poses).to(device).float()
    auds = torch.Tensor(auds).to(device).float()

    if args.render_only:
        # Pure rendering mode: produce frames + video from loaded weights,
        # then exit without optimizing.
        print('RENDER ONLY')
        with torch.no_grad():
            # Default is smoother render_poses path
            images = None
            testsavedir = os.path.join(basedir, expname, 'renderonly_{}_{:06d}'.format(
                'test' if args.render_test else 'path', start))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses.shape)
            auds_val = AudNet(auds)
            rgbs, disp, last_weight = render_path(poses, auds_val, bc_img, hwfcxy, args.chunk, render_kwargs_test,
                                                  gt_imgs=images, savedir=testsavedir, render_factor=args.render_factor)
            np.save(os.path.join(testsavedir, 'last_weight.npy'), last_weight)
            print('Done rendering', testsavedir)
            imageio.mimwrite(os.path.join(
                testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
            return

    num_frames = images.shape[0]  # NOTE(review): unused below

    # Prepare raybatch tensor if batching random rays
    N_rand = args.N_rand
    print('N_rand', N_rand, 'no_batching',
          args.no_batching, 'sample_rate', args.sample_rate)
    use_batching = not args.no_batching

    if use_batching:
        # For random ray batching: precompute all rays for all training
        # images, pair them with pixel colors, and shuffle once.
        print('get rays')
        rays = np.stack([get_rays_np(H, W, focal, p, cx, cy)
                         for p in poses[:, :3, :4]], 0)  # [N, ro+rd, H, W, 3]
        print('done, concats')
        # [N, ro+rd+rgb, H, W, 3]
        rays_rgb = np.concatenate([rays, images[:, None]], 1)
        # [N, H, W, ro+rd+rgb, 3]
        rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
        rays_rgb = np.stack([rays_rgb[i]
                             for i in i_train], 0)  # train images only
        # [(N-1)*H*W, ro+rd+rgb, 3]
        rays_rgb = np.reshape(rays_rgb, [-1, 3, 3])
        rays_rgb = rays_rgb.astype(np.float32)
        print('shuffle rays')
        np.random.shuffle(rays_rgb)

        print('done')
        i_batch = 0

    if use_batching:
        rays_rgb = torch.Tensor(rays_rgb).to(device)

    N_iters = args.N_iters + 1
    print('Begin')
    print('TRAIN views are', i_train)
    print('VAL views are', i_val)

    start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()

        # Sample random ray batch
        if use_batching:
            # Random over all images
            batch = rays_rgb[i_batch:i_batch+N_rand]  # [B, 2+1, 3*?]
            batch = torch.transpose(batch, 0, 1)
            batch_rays, target_s = batch[:2], batch[2]

            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                # Reshuffle once the whole precomputed ray set is consumed.
                print("Shuffle data after an epoch!")
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0

        else:
            # Random from one image: pick one training frame and sample rays.
            img_i = np.random.choice(i_train)
            # NOTE(review): images[img_i] is passed to imageio.imread, so in
            # this branch `images` appears to hold file paths — confirm against
            # load_audface_data.
            target = torch.as_tensor(imageio.imread(
                images[img_i])).to(device).float()/255.0
            pose = poses[img_i, :3, :4]
            rect = sample_rects[img_i]
            mouth_rect = mouth_rects[img_i]  # NOTE(review): unused below
            aud = auds[img_i]
            if global_step >= args.nosmo_iters:
                # After nosmo_iters: smooth audio features over a window of
                # smo_size neighboring frames, zero-padding at sequence ends.
                smo_half_win = int(args.smo_size / 2)
                left_i = img_i - smo_half_win
                right_i = img_i + smo_half_win
                pad_left, pad_right = 0, 0
                if left_i < 0:
                    pad_left = -left_i
                    left_i = 0
                if right_i > i_train.shape[0]:
                    pad_right = right_i-i_train.shape[0]
                    right_i = i_train.shape[0]
                auds_win = auds[left_i:right_i]
                if pad_left > 0:
                    auds_win = torch.cat(
                        (torch.zeros_like(auds_win)[:pad_left], auds_win), dim=0)
                if pad_right > 0:
                    auds_win = torch.cat(
                        (auds_win, torch.zeros_like(auds_win)[:pad_right]), dim=0)
                auds_win = AudNet(auds_win)
                aud = auds_win[smo_half_win]
                aud_smo = AudAttNet(auds_win)
            else:
                aud = AudNet(aud.unsqueeze(0))
            if N_rand is not None:
                rays_o, rays_d = get_rays(
                    H, W, focal, torch.Tensor(pose), cx, cy)  # (H, W, 3), (H, W, 3)

                if i < args.precrop_iters:
                    # Early iterations sample only a central crop of the image.
                    dH = int(H//2 * args.precrop_frac)
                    dW = int(W//2 * args.precrop_frac)
                    coords = torch.stack(
                        torch.meshgrid(
                            torch.linspace(H//2 - dH, H//2 + dH - 1, 2*dH),
                            torch.linspace(W//2 - dW, W//2 + dW - 1, 2*dW)
                        ), -1)
                    if i == start:
                        print(
                            f"[Config] Center cropping of size {2*dH} x {2*dW} is enabled until iter {args.precrop_iters}")
                else:
                    coords = torch.stack(torch.meshgrid(torch.linspace(
                        0, H-1, H), torch.linspace(0, W-1, W)), -1)  # (H, W, 2)

                coords = torch.reshape(coords, [-1, 2])  # (H * W, 2)
                if args.sample_rate > 0:
                    # Oversample pixels inside the face bounding box `rect`:
                    # sample_rate of the batch inside it, the rest outside.
                    rect_inds = (coords[:, 0] >= rect[0]) & (
                        coords[:, 0] <= rect[0] + rect[2]) & (
                            coords[:, 1] >= rect[1]) & (
                                coords[:, 1] <= rect[1] + rect[3])
                    coords_rect = coords[rect_inds]
                    coords_norect = coords[~rect_inds]
                    rect_num = int(N_rand*args.sample_rate)
                    norect_num = N_rand - rect_num
                    select_inds_rect = np.random.choice(
                        coords_rect.shape[0], size=[rect_num], replace=False)  # (N_rand,)
                    # (N_rand, 2)
                    select_coords_rect = coords_rect[select_inds_rect].long()
                    select_inds_norect = np.random.choice(
                        coords_norect.shape[0], size=[norect_num], replace=False)  # (N_rand,)
                    # (N_rand, 2)
                    select_coords_norect = coords_norect[select_inds_norect].long(
                    )
                    select_coords = torch.cat(
                        (select_coords_rect, select_coords_norect), dim=0)
                else:
                    select_inds = np.random.choice(
                        coords.shape[0], size=[N_rand], replace=False)  # (N_rand,)
                    select_coords = coords[select_inds].long()

                # Gather the chosen pixels' rays, colors and background colors.
                rays_o = rays_o[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                rays_d = rays_d[select_coords[:, 0],
                                select_coords[:, 1]]  # (N_rand, 3)
                batch_rays = torch.stack([rays_o, rays_d], 0)
                target_s = target[select_coords[:, 0],
                                  select_coords[:, 1]]  # (N_rand, 3)
                bc_rgb = bc_img[select_coords[:, 0],
                                select_coords[:, 1]]

        #####  Core optimization loop  #####
        # Render the sampled rays conditioned on the (smoothed) audio feature.
        if global_step >= args.nosmo_iters:
            rgb, disp, acc, _, extras = render_dynamic_face(H, W, focal, cx, cy, chunk=args.chunk, rays=batch_rays,
                                                            aud_para=aud_smo, bc_rgb=bc_rgb,
                                                            verbose=i < 10, retraw=True,
                                                            **render_kwargs_train)
        else:
            rgb, disp, acc, _, extras = render_dynamic_face(H, W, focal, cx, cy, chunk=args.chunk, rays=batch_rays,
                                                            aud_para=aud, bc_rgb=bc_rgb,
                                                            verbose=i < 10, retraw=True,
                                                            **render_kwargs_train)

        optimizer.zero_grad()
        optimizer_Aud.zero_grad()
        optimizer_AudAtt.zero_grad()
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][..., -1]  # NOTE(review): unused below
        loss = img_loss
        psnr = mse2psnr(img_loss)

        if 'rgb0' in extras:
            # Add the coarse network's loss when a fine network is in use.
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)  # NOTE(review): unused below

        loss.backward()
        optimizer.step()
        optimizer_Aud.step()
        # AudAttNet only receives gradients once smoothing is active.
        if global_step >= args.nosmo_iters:
            optimizer_AudAtt.step()
        # NOTE: IMPORTANT!
        ###   update learning rate   ###
        # Exponential decay; the attention net uses a 5x learning rate.
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate

        for param_group in optimizer_Aud.param_groups:
            param_group['lr'] = new_lrate

        for param_group in optimizer_AudAtt.param_groups:
            param_group['lr'] = new_lrate*5
        ################################

        dt = time.time()-time0  # NOTE(review): measured but never reported

        # Rest is logging
        if i % args.i_weights == 0:
            # Periodic checkpoint of all networks and optimizers.
            path = os.path.join(basedir, expname, '{:06d}_head.tar'.format(i))
            torch.save({
                'global_step': global_step,
                'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),
                'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),
                'network_audnet_state_dict': AudNet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'optimizer_aud_state_dict': optimizer_Aud.state_dict(),
                'network_audattnet_state_dict': AudAttNet.state_dict(),
                'optimizer_audatt_state_dict': optimizer_AudAtt.state_dict(),
            }, path)
            print('Saved checkpoints at', path)

        if i % args.i_testset == 0 and i > 0:
            # Periodic rendering of the validation views (no gradients).
            testsavedir = os.path.join(
                basedir, expname, 'testset_{:06d}'.format(i))
            os.makedirs(testsavedir, exist_ok=True)
            print('test poses shape', poses[i_val].shape)
            auds_val = AudNet(auds[i_val])
            with torch.no_grad():
                render_path(torch.Tensor(poses[i_val]).to(
                    device), auds_val, bc_img, hwfcxy, args.chunk, render_kwargs_test, gt_imgs=None, savedir=testsavedir)
            print('Saved test set')

        if i % args.i_print == 0:
            tqdm.write(
                f"[TRAIN] Iter: {i} Loss: {loss.item()} PSNR: {psnr.item()}")
        global_step += 1


if __name__ == '__main__':
    # All new tensors default to CUDA floats; training assumes a GPU.
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

    train()
class AbstractTrainer(object):
    r"""Common interface for trainers in the talkingface toolkit.

    A trainer owns one model plus its configuration and drives the training
    and evaluation processes. Concrete subclasses realize their own strategy
    by overriding :meth:`fit` and :meth:`evaluate`; this base class merely
    stores the two collaborators.
    """

    def __init__(self, config, model):
        # Keep references so subclasses can reach the configuration and model.
        self.config = config
        self.model = model

    def fit(self, train_data):
        r"""Train the model based on the train data."""
        raise NotImplementedError("Method [next] should be implemented.")

    def evaluate(self, eval_data):
        r"""Evaluate the model based on the eval data."""
        raise NotImplementedError("Method [next] should be implemented.")
+ + """ + def __init__(self, config, model): + super(Trainer, self).__init__(config, model) + self.logger = getLogger() + self.tensorboard = get_tensorboard(self.logger) + self.wandblogger = WandbLogger(config) + # self.enable_amp = config["enable_amp"] + # self.enable_scaler = torch.cuda.is_available() and config["enable_scaler"] + + # config for train + self.learner = config["learner"] + self.learning_rate = config["learning_rate"] + self.epochs = config["epochs"] + self.eval_step = min(config["eval_step"], self.epochs) + self.stopping_step = config["stopping_step"] + self.test_batch_size = config["eval_batch_size"] + self.gpu_available = torch.cuda.is_available() and config["use_gpu"] + self.device = config["device"] + self.checkpoint_dir = config["checkpoint_dir"] + ensure_dir(self.checkpoint_dir) + saved_model_file = "{}-{}.pth".format(self.config["model"], get_local_time()) + self.saved_model_file = os.path.join(self.checkpoint_dir, saved_model_file) + self.weight_decay = config["weight_decay"] + self.start_epoch = 0 + self.cur_step = 0 + self.train_loss_dict = dict() + self.optimizer = self._build_optimizer() + self.evaluator = Evaluator(config) + + self.valid_metric_bigger = config["valid_metric_bigger"] + self.best_valid_score = -np.inf if self.valid_metric_bigger else np.inf + self.best_valid_result = None + + def _build_optimizer(self, **kwargs): + params = kwargs.pop("params", self.model.parameters()) + learner = kwargs.pop("learner", self.learner) + learning_rate = kwargs.pop("learning_rate", self.learning_rate) + weight_decay = kwargs.pop("weight_decay", self.weight_decay) + if (self.config["reg_weight"] and weight_decay and weight_decay * self.config["reg_weight"] > 0): + self.logger.warning( + "The parameters [weight_decay] and [reg_weight] are specified simultaneously, " + "which may lead to double regularization." 
+ ) + + if learner.lower() == "adam": + optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay) + elif learner.lower() == "adamw": + optimizer = optim.AdamW(params, lr=learning_rate, weight_decay=weight_decay) + elif learner.lower() == "sgd": + optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay) + elif learner.lower() == "adagrad": + optimizer = optim.Adagrad( + params, lr=learning_rate, weight_decay=weight_decay + ) + elif learner.lower() == "rmsprop": + optimizer = optim.RMSprop( + params, lr=learning_rate, weight_decay=weight_decay + ) + elif learner.lower() == "sparse_adam": + optimizer = optim.SparseAdam(params, lr=learning_rate) + if weight_decay > 0: + self.logger.warning( + "Sparse Adam cannot argument received argument [{weight_decay}]" + ) + else: + self.logger.warning( + "Received unrecognized optimizer, set default Adam optimizer" + ) + optimizer = optim.Adam(params, lr=learning_rate) + return optimizer + + def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False): + r"""Train the model in an epoch + + Args: + train_data (DataLoader): The train data. + epoch_idx (int): The current epoch id. + loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be + :attr:`self.model.calculate_loss`. Defaults to ``None``. + show_progress (bool): Show the progress of training epoch. Defaults to ``False``. 
+ + Returns: + the averaged loss of this epoch + """ + self.model.train() + loss_func = loss_func or self.model.calculate_loss + total_loss_dict = {} + step = 0 + iter_data = ( + tqdm( + train_data, + total=len(train_data), + ncols=None, + ) + if show_progress + else train_data + ) + + for batch_idx, interaction in enumerate(iter_data): + self.optimizer.zero_grad() + step += 1 + losses_dict = loss_func(interaction) + loss = losses_dict["loss"] + + for key, value in losses_dict.items(): + if key in total_loss_dict: + if not torch.is_tensor(value): + total_loss_dict[key] += value + # 如果键已经在总和字典中,累加当前值 + else: + losses_dict[key] = value.item() + total_loss_dict[key] += value.item() + else: + if not torch.is_tensor(value): + total_loss_dict[key] = value + # 否则,将当前值添加到字典中 + else: + losses_dict[key] = value.item() + total_loss_dict[key] = value.item() + iter_data.set_description(set_color(f"train {epoch_idx} {losses_dict}", "pink")) + + self._check_nan(loss) + loss.backward() + self.optimizer.step() + average_loss_dict = {} + for key, value in total_loss_dict.items(): + average_loss_dict[key] = value/step + + return average_loss_dict + + + + def _valid_epoch(self, valid_data, show_progress=False): + r"""Valid the model with valid data. Different from the evaluate, this is use for training. + + Args: + valid_data (DataLoader): the valid data. + show_progress (bool): Show the progress of evaluate epoch. Defaults to ``False``. 
+ + Returns: + loss + """ + print('Valid for {} steps'.format(self.eval_steps)) + self.model.eval() + total_loss_dict = {} + iter_data = ( + tqdm(valid_data, + total=len(valid_data), + ncols=None, + ) + if show_progress + else valid_data + ) + step = 0 + for batch_idx, batched_data in enumerate(iter_data): + step += 1 + batched_data.to(self.device) + losses_dict = self.model.calculate_loss(batched_data, valid=True) + for key, value in losses_dict.items(): + if key in total_loss_dict: + if not torch.is_tensor(value): + total_loss_dict[key] += value + # 如果键已经在总和字典中,累加当前值 + else: + losses_dict[key] = value.item() + total_loss_dict[key] += value.item() + else: + if not torch.is_tensor(value): + total_loss_dict[key] = value + # 否则,将当前值添加到字典中 + else: + losses_dict[key] = value.item() + total_loss_dict[key] = value.item() + iter_data.set_description(set_color(f"Valid {losses_dict}", "pink")) + average_loss_dict = {} + for key, value in total_loss_dict.items(): + average_loss_dict[key] = value/step + + return average_loss_dict + + + + + def _save_checkpoint(self, epoch, verbose=True, **kwargs): + r"""Store the model parameters information and training information. + + Args: + epoch (int): the current epoch id + + """ + saved_model_file = kwargs.pop("saved_model_file", self.saved_model_file) + state = { + "config": self.config, + "epoch": epoch, + "cur_step": self.cur_step, + "best_valid_score": self.best_valid_score, + "state_dict": self.model.state_dict(), + "other_parameter": self.model.other_parameter(), + "optimizer": self.optimizer.state_dict(), + } + torch.save(state, saved_model_file, pickle_protocol=4) + if verbose: + self.logger.info( + set_color("Saving current", "blue") + f": {saved_model_file}" + ) + + def resume_checkpoint(self, resume_file): + r"""Load the model parameters information and training information. 
+ + Args: + resume_file (file): the checkpoint file + + """ + resume_file = str(resume_file) + self.saved_model_file = resume_file + checkpoint = torch.load(resume_file, map_location=self.device) + self.start_epoch = checkpoint["epoch"] + 1 + self.cur_step = checkpoint["cur_step"] + # self.best_valid_score = checkpoint["best_valid_score"] + + # load architecture params from checkpoint + if checkpoint["config"]["model"].lower() != self.config["model"].lower(): + self.logger.warning( + "Architecture configuration given in config file is different from that of checkpoint. " + "This may yield an exception while state_dict is being loaded." + ) + self.model.load_state_dict(checkpoint["state_dict"]) + self.model.load_other_parameter(checkpoint.get("other_parameter")) + + # load optimizer state from checkpoint only when optimizer type is not changed + self.optimizer.load_state_dict(checkpoint["optimizer"]) + message_output = "Checkpoint loaded. Resume training from epoch {}".format( + self.start_epoch + ) + self.logger.info(message_output) + + def _check_nan(self, loss): + if torch.isnan(loss): + raise ValueError("Training loss is nan") + + def _generate_train_loss_output(self, epoch_idx, s_time, e_time, losses): + des = self.config["loss_decimal_place"] or 4 + train_loss_output = ( + set_color(f"epoch {epoch_idx} training", "green") + + " [" + + set_color("time", "blue") + + f": {e_time - s_time:.2f}s, " + ) + # 遍历字典,格式化并添加每个损失项 + loss_items = [ + set_color(f"{key}", "blue") + f": {value:.{des}f}" + for key, value in losses.items() + ] + # 将所有损失项连接成一个字符串,并与前面的输出拼接 + train_loss_output += ", ".join(loss_items) + return train_loss_output + "]" + + def _add_hparam_to_tensorboard(self, best_valid_result): + # base hparam + hparam_dict = { + "learner": self.config["learner"], + "learning_rate": self.config["learning_rate"], + "train_batch_size": self.config["train_batch_size"], + } + # unrecorded parameter + unrecorded_parameter = { + parameter + for parameters in 
self.config.parameters.values() + for parameter in parameters + }.union({"model", "dataset", "config_files", "device"}) + # other model-specific hparam + hparam_dict.update( + { + para: val + for para, val in self.config.final_config_dict.items() + if para not in unrecorded_parameter + } + ) + for k in hparam_dict: + if hparam_dict[k] is not None and not isinstance( + hparam_dict[k], (bool, str, float, int) + ): + hparam_dict[k] = str(hparam_dict[k]) + + self.tensorboard.add_hparams( + hparam_dict, {"hparam/best_valid_result": best_valid_result} + ) + + def fit(self, train_data, valid_data=None, verbose=True, saved=True, show_progress=False, callback_fn=None): + r"""Train the model based on the train data and the valid data. + + Args: + train_data (DataLoader): the train data + valid_data (DataLoader, optional): the valid data, default: None. + If it's None, the early_stopping is invalid. + verbose (bool, optional): whether to write training and evaluation information to logger, default: True + saved (bool, optional): whether to save the model parameters, default: True + show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``. + callback_fn (callable): Optional callback function executed at end of epoch. + Includes (epoch_idx, valid_score) input arguments. 
+ + Returns: + best result + """ + if saved and self.start_epoch >= self.epochs: + self._save_checkpoint(-1, verbose=verbose) + + if not (self.config['resume_checkpoint_path'] == None ) and self.config['resume']: + self.resume_checkpoint(self.config['resume_checkpoint_path']) + + for epoch_idx in range(self.start_epoch, self.epochs): + training_start_time = time() + train_loss = self._train_epoch(train_data, epoch_idx, show_progress=show_progress) + self.train_loss_dict[epoch_idx] = ( + sum(train_loss) if isinstance(train_loss, tuple) else train_loss + ) + training_end_time = time() + train_loss_output = self._generate_train_loss_output( + epoch_idx, training_start_time, training_end_time, train_loss) + + if verbose: + self.logger.info(train_loss_output) + # self._add_train_loss_to_tensorboard(epoch_idx, train_loss) + + if self.eval_step <= 0 or not valid_data: + if saved: + self._save_checkpoint(epoch_idx, verbose=verbose) + continue + + if (epoch_idx + 1) % self.eval_step == 0: + valid_start_time = time() + valid_loss = self._valid_epoch(valid_data=valid_data, show_progress=show_progress) + + (self.best_valid_score, self.cur_step, stop_flag,update_flag,) = early_stopping( + valid_loss['loss'], + self.best_valid_score, + self.cur_step, + max_step=self.stopping_step, + bigger=self.valid_metric_bigger, + ) + valid_end_time = time() + + valid_loss_output = ( + set_color("valid result", "blue") + ": \n" + dict2str(valid_loss) + ) + if verbose: + self.logger.info(valid_loss_output) + + + if update_flag: + if saved: + self._save_checkpoint(epoch_idx, verbose=verbose) + self.best_valid_result = valid_loss['loss'] + + if stop_flag: + stop_output = "Finished training, best eval result in epoch %d" % ( + epoch_idx - self.cur_step * self.eval_step + ) + if verbose: + self.logger.info(stop_output) + break + @torch.no_grad() + def evaluate(self, load_best_model=True, model_file=None): + """ + Evaluate the model based on the test data. 
+ + args: load_best_model: bool, whether to load the best model in the training process. + model_file: str, the model file you want to evaluate. + + """ + if load_best_model: + checkpoint_file = model_file or self.saved_model_file + checkpoint = torch.load(checkpoint_file, map_location=self.device) + self.model.load_state_dict(checkpoint["state_dict"]) + self.model.load_other_parameter(checkpoint.get("other_parameter")) + message_output = "Loading model structure and parameters from {}".format( + checkpoint_file + ) + self.logger.info(message_output) + self.model.eval() + + datadict = self.model.generate_batch() + eval_result = self.evaluator.evaluate(datadict) + self.logger.info(eval_result) + + + +class Wav2LipTrainer(Trainer): + def __init__(self, config, model): + super(Wav2LipTrainer, self).__init__(config, model) + + def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False): + r"""Train the model in an epoch + + Args: + train_data (DataLoader): The train data. + epoch_idx (int): The current epoch id. + loss_func (function): The loss function of :attr:`model`. If it is ``None``, the loss function will be + :attr:`self.model.calculate_loss`. Defaults to ``None``. + show_progress (bool): Show the progress of training epoch. Defaults to ``False``. 
class Wav2LipTrainer(Trainer):
    """Trainer for the Wav2Lip model.

    Training is identical to the generic :class:`Trainer`; validation
    additionally switches on the SyncNet loss weight once the averaged
    lip-sync loss drops below 0.75.
    """

    def __init__(self, config, model):
        super(Wav2LipTrainer, self).__init__(config, model)

    def _train_epoch(self, train_data, epoch_idx, loss_func=None, show_progress=False):
        r"""Train the model in an epoch.

        The original method was a byte-for-byte duplicate of
        ``Trainer._train_epoch``; delegate instead of re-implementing.

        Returns:
            dict: each loss key averaged over the epoch's batches.
        """
        return super(Wav2LipTrainer, self)._train_epoch(
            train_data, epoch_idx, loss_func=loss_func, show_progress=show_progress
        )

    @torch.no_grad()
    def _valid_epoch(self, valid_data, loss_func=None, show_progress=False):
        r"""Validate the model and toggle the SyncNet weight.

        Args:
            valid_data (DataLoader): the valid data.
            loss_func: unused, kept for interface compatibility.
            show_progress (bool): show a tqdm progress bar.

        Returns:
            dict: each validation loss key averaged over the batches.
        """
        # Bug fix: the original printed 'Valid'.format(self.eval_step) —
        # a format call with no placeholder, silently dropping the value.
        print('Valid for {} steps'.format(self.eval_step))
        self.model.eval()
        total_loss_dict = {}
        iter_data = (
            tqdm(valid_data,
                 total=len(valid_data),
                 ncols=None,
                 desc=set_color("Valid", "pink")
                 )
            if show_progress
            else valid_data
        )
        step = 0
        for batch_idx, batched_data in enumerate(iter_data):
            step += 1
            losses_dict = self.model.calculate_loss(batched_data, valid=True)
            for key, value in losses_dict.items():
                scalar = value.item() if torch.is_tensor(value) else value
                total_loss_dict[key] = total_loss_dict.get(key, 0.0) + scalar

        average_loss_dict = {}
        if step > 0:
            average_loss_dict = {key: value / step for key, value in total_loss_dict.items()}

        # Enable SyncNet supervision once the sync loss is low enough.
        # Bug fix: the original gated on the LAST batch's sync_loss (and
        # raised KeyError when absent); upstream Wav2Lip uses the averaged
        # value over the whole validation pass.
        if average_loss_dict.get("sync_loss", float("inf")) < .75:
            self.model.config["syncnet_wt"] = 0.01
        return average_loss_dict
# Names of configuration options, grouped by concern. These lists drive how
# the resolved configuration is grouped when printed/collected elsewhere in
# the toolkit.

# General/runtime options.
general_arguments = [
    'gpu_id', 'use_gpu',
    'seed',
    'reproducibility',
    'state',
    'checkpoint_dir',
    'show_progress',
    'config_file',
    'log_wandb',
]

# Training options.
training_arguments = [
    'epochs', 'train_batch_size',
    'learner', 'learning_rate',
    'eval_step', 'stopping_step',
    'weight_decay',
    # Bug fix: a missing comma made Python concatenate the two adjacent
    # string literals into the single bogus entry 'resumetrain'.
    'resume',
    'train',
]

# Evaluation options.
evaluation_arguments = [
    'metrics',
    'temp_dir',
    'evaluate_batch_size',
    'lse_checkpoint_path',
    'valid_metric_bigger',
]
# evaluation_arguments = [
#     'eval_args', 'repeatable',
#     'metrics', 'topk', 'valid_metric', 'valid_metric_bigger',
#     'eval_batch_size',
#     'metric_decimal_place',
# ]

# dataset_arguments = [
#     'field_separator', 'seq_separator',
#     'USER_ID_FIELD', 'ITEM_ID_FIELD', 'RATING_FIELD', 'TIME_FIELD',
#     'seq_len',
#     'LABEL_FIELD', 'threshold',
#     'NEG_PREFIX',
#     'ITEM_LIST_LENGTH_FIELD', 'LIST_SUFFIX', 'MAX_ITEM_LIST_LENGTH', 'POSITION_FIELD',
#     'HEAD_ENTITY_ID_FIELD', 'TAIL_ENTITY_ID_FIELD', 'RELATION_ID_FIELD', 'ENTITY_ID_FIELD',
#     'load_col', 'unload_col', 'unused_col', 'additional_feat_suffix',
#     'rm_dup_inter', 'val_interval', 'filter_inter_by_user_or_item',
#     'user_inter_num_interval', 'item_inter_num_interval',
#     'alias_of_user_id', 'alias_of_item_id', 'alias_of_entity_id', 'alias_of_relation_id',
#     'preload_weight', 'normalize_field', 'normalize_all',
#     'benchmark_filename',
# ]
process_video_file(self, vfile, gpu_id): + video_stream = cv2.VideoCapture(vfile) + + frames = [] + while 1: + still_reading, frame = video_stream.read() + if not still_reading: + video_stream.release() + break + frames.append(frame) + + vidname = os.path.basename(vfile).split('.')[0] + dirname = vfile.split('/')[-2] + + fulldir = os.path.join(self.config['preprocessed_root'], dirname, vidname) + os.makedirs(fulldir, exist_ok=True) + + batches = [frames[i:i + self.config['preprocess_batch_size']] for i in range(0, len(frames), self.config['preprocess_batch_size'])] + + i = -1 + for fb in batches: + preds = self.fa[gpu_id].get_detections_for_batch(np.asarray(fb)) + + for j, f in enumerate(preds): + i += 1 + if f is None: + continue + + x1, y1, x2, y2 = f + cv2.imwrite(os.path.join(fulldir, '{}.jpg'.format(i)), fb[j][y1:y2, x1:x2]) + + + def process_audio_file(self, vfile): + vidname = os.path.basename(vfile).split('.')[0] + dirname = vfile.split('/')[-2] + + fulldir = os.path.join(self.config['preprocessed_root'], dirname, vidname) + os.makedirs(fulldir, exist_ok=True) + + wavpath = os.path.join(fulldir, 'audio.wav') + + command =self.template.format(vfile, wavpath) + subprocess.call(command, shell=True) + + def mp_handler(self, job): + vfile, gpu_id = job + try: + self.process_video_file(vfile, gpu_id) + except KeyboardInterrupt: + exit(0) + except: + traceback.print_exc() + + def run(self): + print(f'Started processing for {self.config["data_root"]} with {self.config["ngpu"]} GPUs') + + filelist = glob(os.path.join(self.config["data_root"], '*/*.mp4')) + + # jobs = [(vfile, i % self.config["ngpu"]) for i, vfile in enumerate(filelist)] + # with ThreadPoolExecutor(self.config["ngpu"]) as p: + # futures = [p.submit(self.mp_handler, j) for j in jobs] + # _ = [r.result() for r in tqdm(as_completed(futures), total=len(futures))] + + print('Dumping audios...') + for vfile in tqdm(filelist): + try: + self.process_audio_file(vfile) + except KeyboardInterrupt: + exit(0) + 
except: + traceback.print_exc() + continue + diff --git a/talkingface-toolkit-main/talkingface/utils/enum_type.py b/talkingface-toolkit-main/talkingface/utils/enum_type.py new file mode 100644 index 00000000..b701b38b --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/enum_type.py @@ -0,0 +1,13 @@ +from enum import Enum + +class EvaluatorType(Enum): + """Type for evaluation metrics. + + - ``SYNC``: SYNC metrics like LSE-C, LSE-D, etc. + - ``VIDEOQ``: Video quality metrics like FID, etc. + - ``AUDIOQ``: Audio quality metrics like PESQ, etc. + """ + + SYNC = 1 + VIDEOQ = 2 + AUDIOQ = 3 \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/README.md b/talkingface-toolkit-main/talkingface/utils/face_detection/README.md new file mode 100644 index 00000000..c073376e --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/README.md @@ -0,0 +1 @@ +The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time. 
\ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/__init__.py b/talkingface-toolkit-main/talkingface/utils/face_detection/__init__.py new file mode 100644 index 00000000..4bae29fd --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/__init__.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +__author__ = """Adrian Bulat""" +__email__ = 'adrian.bulat@nottingham.ac.uk' +__version__ = '1.0.1' + +from .api import FaceAlignment, LandmarksType, NetworkSize diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/api.py b/talkingface-toolkit-main/talkingface/utils/face_detection/api.py new file mode 100644 index 00000000..dc4dd4f2 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/api.py @@ -0,0 +1,79 @@ +from __future__ import print_function +import os +import torch +from torch.utils.model_zoo import load_url +from enum import Enum +import numpy as np +import cv2 +try: + import urllib.request as request_file +except BaseException: + import urllib as request_file + +# from .models import FAN, ResNetDepth +# from .utils import * + + +class LandmarksType(Enum): + """Enum class defining the type of landmarks to detect. 
+ + ``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face + ``_2halfD`` - these points represent the projection of the 3D points onto the 2D image plane + ``_3D`` - detect the points ``(x,y,z)`` in a 3D space + + """ + _2D = 1 + _2halfD = 2 + _3D = 3 + + +class NetworkSize(Enum): + # TINY = 1 + # SMALL = 2 + # MEDIUM = 3 + LARGE = 4 + + def __new__(cls, value): + member = object.__new__(cls) + member._value_ = value + return member + + def __int__(self): + return self.value + +ROOT = os.path.dirname(os.path.abspath(__file__)) + +class FaceAlignment: + def __init__(self, landmarks_type, network_size=NetworkSize.LARGE, + device='cuda', flip_input=False, face_detector='sfd', verbose=False): + self.device = device + self.flip_input = flip_input + self.landmarks_type = landmarks_type + self.verbose = verbose + + network_size = int(network_size) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + # Get the face detector + face_detector_module = __import__('talkingface.utils.face_detection.detection.'
+ face_detector, + globals(), locals(), [face_detector], 0) + self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose) + + def get_detections_for_batch(self, images): + images = images[..., ::-1] + detected_faces = self.face_detector.detect_from_batch(images.copy()) + results = [] + + for i, d in enumerate(detected_faces): + if len(d) == 0: + results.append(None) + continue + d = d[0] + d = np.clip(d, 0, None) + + x1, y1, x2, y2 = map(int, d[:-1]) + results.append((x1, y1, x2, y2)) + + return results \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/__init__.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/__init__.py new file mode 100644 index 00000000..1a6b0402 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/__init__.py @@ -0,0 +1 @@ +from .core import FaceDetector \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/core.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/core.py new file mode 100644 index 00000000..0f8275e8 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/core.py @@ -0,0 +1,130 @@ +import logging +import glob +from tqdm import tqdm +import numpy as np +import torch +import cv2 + + +class FaceDetector(object): + """An abstract class representing a face detector. + + Any other face detection implementation must subclass it. All subclasses + must implement ``detect_from_image``, that return a list of detected + bounding boxes. Optionally, for speed considerations detect from path is + recommended. 
+ """ + + def __init__(self, device, verbose): + self.device = device + self.verbose = verbose + + if verbose: + if 'cpu' in device: + logger = logging.getLogger(__name__) + logger.warning("Detection running on CPU, this may be potentially slow.") + + if 'cpu' not in device and 'cuda' not in device: + if verbose: + logger.error("Expected values for device are: {cpu, cuda} but got: %s", device) + raise ValueError + + def detect_from_image(self, tensor_or_path): + """Detects faces in a given image. + + This function detects the faces present in a provided BGR(usually) + image. The input can be either the image itself or the path to it. + + Arguments: + tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path + to an image or the image itself. + + Example:: + + >>> path_to_image = 'data/image_01.jpg' + ... detected_faces = detect_from_image(path_to_image) + [A list of bounding boxes (x1, y1, x2, y2)] + >>> image = cv2.imread(path_to_image) + ... detected_faces = detect_from_image(image) + [A list of bounding boxes (x1, y1, x2, y2)] + + """ + raise NotImplementedError + + def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True): + """Detects faces from all the images present in a given directory. + + Arguments: + path {string} -- a string containing a path that points to the folder containing the images + + Keyword Arguments: + extensions {list} -- list of string containing the extensions to be + consider in the following format: ``.extension_name`` (default: + {['.jpg', '.png']}) recursive {bool} -- option wherever to scan the + folder recursively (default: {False}) show_progress_bar {bool} -- + display a progressbar (default: {True}) + + Example: + >>> directory = 'data' + ... 
detected_faces = detect_from_directory(directory) + {A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]} + + """ + if self.verbose: + logger = logging.getLogger(__name__) + + if len(extensions) == 0: + if self.verbose: + logger.error("Expected at list one extension, but none was received.") + raise ValueError + + if self.verbose: + logger.info("Constructing the list of images.") + additional_pattern = '/**/*' if recursive else '/*' + files = [] + for extension in extensions: + files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive)) + + if self.verbose: + logger.info("Finished searching for images. %s images found", len(files)) + logger.info("Preparing to run the detection.") + + predictions = {} + for image_path in tqdm(files, disable=not show_progress_bar): + if self.verbose: + logger.info("Running the face detector on image: %s", image_path) + predictions[image_path] = self.detect_from_image(image_path) + + if self.verbose: + logger.info("The detector was successfully run on all %s images", len(files)) + + return predictions + + @property + def reference_scale(self): + raise NotImplementedError + + @property + def reference_x_shift(self): + raise NotImplementedError + + @property + def reference_y_shift(self): + raise NotImplementedError + + @staticmethod + def tensor_or_path_to_ndarray(tensor_or_path, rgb=True): + """Convert path (represented as a string) or torch.tensor to a numpy.ndarray + + Arguments: + tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself + """ + if isinstance(tensor_or_path, str): + return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1] + elif torch.is_tensor(tensor_or_path): + # Call cpu in case its coming from cuda + return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy() + elif isinstance(tensor_or_path, np.ndarray): + return tensor_or_path[..., ::-1].copy() if not rgb else 
tensor_or_path + else: + raise TypeError diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/__init__.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/__init__.py new file mode 100644 index 00000000..5a63ecd4 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/__init__.py @@ -0,0 +1 @@ +from .sfd_detector import SFDDetector as FaceDetector \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/bbox.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/bbox.py new file mode 100644 index 00000000..4bd7222e --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/bbox.py @@ -0,0 +1,129 @@ +from __future__ import print_function +import os +import sys +import cv2 +import random +import datetime +import time +import math +import argparse +import numpy as np +import torch + +try: + from iou import IOU +except BaseException: + # IOU cython speedup 10x + def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2): + sa = abs((ax2 - ax1) * (ay2 - ay1)) + sb = abs((bx2 - bx1) * (by2 - by1)) + x1, y1 = max(ax1, bx1), max(ay1, by1) + x2, y2 = min(ax2, bx2), min(ay2, by2) + w = x2 - x1 + h = y2 - y1 + if w < 0 or h < 0: + return 0.0 + else: + return 1.0 * w * h / (sa + sb - w * h) + + +def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh): + xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1 + dx, dy = (xc - axc) / aww, (yc - ayc) / ahh + dw, dh = math.log(ww / aww), math.log(hh / ahh) + return dx, dy, dw, dh + + +def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh): + xc, yc = dx * aww + axc, dy * ahh + ayc + ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh + x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2 + return x1, y1, x2, y2 + + +def nms(dets, thresh): + if 0 == len(dets): + return [] + x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], 
dets[:, 3], dets[:, 4] + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]]) + xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]]) + + w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1) + ovr = w * h / (areas[i] + areas[order[1:]] - w * h) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return keep + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + +def batch_decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:], + priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2) + boxes[:, :, :2] -= boxes[:, :, 2:] / 2 + boxes[:, :, 2:] += boxes[:, :, :2] + return boxes diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/detect.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/detect.py new file mode 100644 index 00000000..efef6273 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/detect.py @@ -0,0 +1,112 @@ +import torch +import torch.nn.functional as F + +import os +import sys +import cv2 +import random +import datetime +import math +import argparse +import numpy as np + +import scipy.io as sio +import zipfile +from .net_s3fd import s3fd +from .bbox import * + + +def detect(net, img, device): + img = img - np.array([104, 117, 123]) + img = img.transpose(2, 0, 1) + img = img.reshape((1,) + img.shape) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + img = torch.from_numpy(img).float().to(device) + BB, CC, HH, WW = img.size() + with torch.no_grad(): + olist = net(img) + + bboxlist = [] + for i in range(len(olist) // 2): 
+ olist[i * 2] = F.softmax(olist[i * 2], dim=1) + olist = [oelem.data.cpu() for oelem in olist] + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.size() # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = stride * 4 + poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls[0, 1, hindex, windex] + loc = oreg[0, :, hindex, windex].contiguous().view(1, 4) + priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]) + variances = [0.1, 0.2] + box = decode(loc, priors, variances) + x1, y1, x2, y2 = box[0] * 1.0 + # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) + bboxlist.append([x1, y1, x2, y2, score]) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, 5)) + + return bboxlist + +def batch_detect(net, imgs, device): + imgs = imgs - np.array([104, 117, 123]) + imgs = imgs.transpose(0, 3, 1, 2) + + if 'cuda' in device: + torch.backends.cudnn.benchmark = True + + imgs = torch.from_numpy(imgs).float().to(device) + BB, CC, HH, WW = imgs.size() + with torch.no_grad(): + olist = net(imgs) + + bboxlist = [] + for i in range(len(olist) // 2): + olist[i * 2] = F.softmax(olist[i * 2], dim=1) + olist = [oelem.data.cpu() for oelem in olist] + for i in range(len(olist) // 2): + ocls, oreg = olist[i * 2], olist[i * 2 + 1] + FB, FC, FH, FW = ocls.size() # feature map size + stride = 2**(i + 2) # 4,8,16,32,64,128 + anchor = stride * 4 + poss = zip(*np.where(ocls[:, 1, :, :] > 0.05)) + for Iindex, hindex, windex in poss: + axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride + score = ocls[:, 1, hindex, windex] + loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4) + priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4) + variances = [0.1, 0.2] + box = 
batch_decode(loc, priors, variances) + box = box[:, 0] * 1.0 + # cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1) + bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy()) + bboxlist = np.array(bboxlist) + if 0 == len(bboxlist): + bboxlist = np.zeros((1, BB, 5)) + + return bboxlist + +def flip_detect(net, img, device): + img = cv2.flip(img, 1) + b = detect(net, img, device) + + bboxlist = np.zeros(b.shape) + bboxlist[:, 0] = img.shape[1] - b[:, 2] + bboxlist[:, 1] = b[:, 1] + bboxlist[:, 2] = img.shape[1] - b[:, 0] + bboxlist[:, 3] = b[:, 3] + bboxlist[:, 4] = b[:, 4] + return bboxlist + + +def pts_to_bb(pts): + min_x, min_y = np.min(pts, axis=0) + max_x, max_y = np.max(pts, axis=0) + return np.array([min_x, min_y, max_x, max_y]) diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/net_s3fd.py b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/net_s3fd.py new file mode 100644 index 00000000..fc64313c --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/net_s3fd.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class L2Norm(nn.Module): + def __init__(self, n_channels, scale=1.0): + super(L2Norm, self).__init__() + self.n_channels = n_channels + self.scale = scale + self.eps = 1e-10 + self.weight = nn.Parameter(torch.Tensor(self.n_channels)) + self.weight.data *= 0.0 + self.weight.data += self.scale + + def forward(self, x): + norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps + x = x / norm * self.weight.view(1, -1, 1, 1) + return x + + +class s3fd(nn.Module): + def __init__(self): + super(s3fd, self).__init__() + self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) + self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) + + self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) + self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, 
padding=1) + + self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) + self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) + + self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) + self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) + + self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3) + self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0) + + self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0) + self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1) + + self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0) + self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1) + + self.conv3_3_norm = L2Norm(256, scale=10) + self.conv4_3_norm = L2Norm(512, scale=8) + self.conv5_3_norm = L2Norm(512, scale=5) + + self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) + self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1) + self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1) + self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1) 
+ self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1) + self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + h = F.relu(self.conv1_1(x)) + h = F.relu(self.conv1_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv2_1(h)) + h = F.relu(self.conv2_2(h)) + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv3_1(h)) + h = F.relu(self.conv3_2(h)) + h = F.relu(self.conv3_3(h)) + f3_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv4_1(h)) + h = F.relu(self.conv4_2(h)) + h = F.relu(self.conv4_3(h)) + f4_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.conv5_1(h)) + h = F.relu(self.conv5_2(h)) + h = F.relu(self.conv5_3(h)) + f5_3 = h + h = F.max_pool2d(h, 2, 2) + + h = F.relu(self.fc6(h)) + h = F.relu(self.fc7(h)) + ffc7 = h + h = F.relu(self.conv6_1(h)) + h = F.relu(self.conv6_2(h)) + f6_2 = h + h = F.relu(self.conv7_1(h)) + h = F.relu(self.conv7_2(h)) + f7_2 = h + + f3_3 = self.conv3_3_norm(f3_3) + f4_3 = self.conv4_3_norm(f4_3) + f5_3 = self.conv5_3_norm(f5_3) + + cls1 = self.conv3_3_norm_mbox_conf(f3_3) + reg1 = self.conv3_3_norm_mbox_loc(f3_3) + cls2 = self.conv4_3_norm_mbox_conf(f4_3) + reg2 = self.conv4_3_norm_mbox_loc(f4_3) + cls3 = self.conv5_3_norm_mbox_conf(f5_3) + reg3 = self.conv5_3_norm_mbox_loc(f5_3) + cls4 = self.fc7_mbox_conf(ffc7) + reg4 = self.fc7_mbox_loc(ffc7) + cls5 = self.conv6_2_mbox_conf(f6_2) + reg5 = self.conv6_2_mbox_loc(f6_2) + cls6 = self.conv7_2_mbox_conf(f7_2) + reg6 = self.conv7_2_mbox_loc(f7_2) + + # max-out background label + chunk = torch.chunk(cls1, 4, 1) + bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2]) + cls1 = torch.cat([bmax, chunk[3]], dim=1) + + return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6] diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/sfd_detector.py 
b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/sfd_detector.py new file mode 100644 index 00000000..8fbce152 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/detection/sfd/sfd_detector.py @@ -0,0 +1,59 @@ +import os +import cv2 +from torch.utils.model_zoo import load_url + +from ..core import FaceDetector + +from .net_s3fd import s3fd +from .bbox import * +from .detect import * + +models_urls = { + 's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth', +} + + +class SFDDetector(FaceDetector): + def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False): + super(SFDDetector, self).__init__(device, verbose) + + # Initialise the face detector + if not os.path.isfile(path_to_detector): + model_weights = load_url(models_urls['s3fd']) + else: + model_weights = torch.load(path_to_detector) + + self.face_detector = s3fd() + self.face_detector.load_state_dict(model_weights) + self.face_detector.to(device) + self.face_detector.eval() + + def detect_from_image(self, tensor_or_path): + image = self.tensor_or_path_to_ndarray(tensor_or_path) + + bboxlist = detect(self.face_detector, image, device=self.device) + keep = nms(bboxlist, 0.3) + bboxlist = bboxlist[keep, :] + bboxlist = [x for x in bboxlist if x[-1] > 0.5] + + return bboxlist + + def detect_from_batch(self, images): + bboxlists = batch_detect(self.face_detector, images, device=self.device) + keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])] + bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)] + bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists] + + return bboxlists + + @property + def reference_scale(self): + return 195 + + @property + def reference_x_shift(self): + return 0 + + @property + def reference_y_shift(self): + return 0 diff --git 
a/talkingface-toolkit-main/talkingface/utils/face_detection/models.py b/talkingface-toolkit-main/talkingface/utils/face_detection/models.py new file mode 100644 index 00000000..ee2dde32 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/models.py @@ -0,0 +1,261 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import math + + +def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False): + "3x3 convolution with padding" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, + stride=strd, padding=padding, bias=bias) + + +class ConvBlock(nn.Module): + def __init__(self, in_planes, out_planes): + super(ConvBlock, self).__init__() + self.bn1 = nn.BatchNorm2d(in_planes) + self.conv1 = conv3x3(in_planes, int(out_planes / 2)) + self.bn2 = nn.BatchNorm2d(int(out_planes / 2)) + self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4)) + self.bn3 = nn.BatchNorm2d(int(out_planes / 4)) + self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4)) + + if in_planes != out_planes: + self.downsample = nn.Sequential( + nn.BatchNorm2d(in_planes), + nn.ReLU(True), + nn.Conv2d(in_planes, out_planes, + kernel_size=1, stride=1, bias=False), + ) + else: + self.downsample = None + + def forward(self, x): + residual = x + + out1 = self.bn1(x) + out1 = F.relu(out1, True) + out1 = self.conv1(out1) + + out2 = self.bn2(out1) + out2 = F.relu(out2, True) + out2 = self.conv2(out2) + + out3 = self.bn3(out2) + out3 = F.relu(out3, True) + out3 = self.conv3(out3) + + out3 = torch.cat((out1, out2, out3), 1) + + if self.downsample is not None: + residual = self.downsample(residual) + + out3 += residual + + return out3 + + +class Bottleneck(nn.Module): + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, 
stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HourGlass(nn.Module): + def __init__(self, num_modules, depth, num_features): + super(HourGlass, self).__init__() + self.num_modules = num_modules + self.depth = depth + self.features = num_features + + self._generate_network(self.depth) + + def _generate_network(self, level): + self.add_module('b1_' + str(level), ConvBlock(self.features, self.features)) + + self.add_module('b2_' + str(level), ConvBlock(self.features, self.features)) + + if level > 1: + self._generate_network(level - 1) + else: + self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features)) + + self.add_module('b3_' + str(level), ConvBlock(self.features, self.features)) + + def _forward(self, level, inp): + # Upper branch + up1 = inp + up1 = self._modules['b1_' + str(level)](up1) + + # Lower branch + low1 = F.avg_pool2d(inp, 2, stride=2) + low1 = self._modules['b2_' + str(level)](low1) + + if level > 1: + low2 = self._forward(level - 1, low1) + else: + low2 = low1 + low2 = self._modules['b2_plus_' + str(level)](low2) + + low3 = low2 + low3 = self._modules['b3_' + str(level)](low3) + + up2 = F.interpolate(low3, scale_factor=2, mode='nearest') + + return up1 + up2 + + def forward(self, x): + return self._forward(self.depth, x) + + +class FAN(nn.Module): + + def __init__(self, num_modules=1): + super(FAN, self).__init__() + self.num_modules = 
num_modules + + # Base part + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(64) + self.conv2 = ConvBlock(64, 128) + self.conv3 = ConvBlock(128, 128) + self.conv4 = ConvBlock(128, 256) + + # Stacking part + for hg_module in range(self.num_modules): + self.add_module('m' + str(hg_module), HourGlass(1, 4, 256)) + self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256)) + self.add_module('conv_last' + str(hg_module), + nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256)) + self.add_module('l' + str(hg_module), nn.Conv2d(256, + 68, kernel_size=1, stride=1, padding=0)) + + if hg_module < self.num_modules - 1: + self.add_module( + 'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0)) + self.add_module('al' + str(hg_module), nn.Conv2d(68, + 256, kernel_size=1, stride=1, padding=0)) + + def forward(self, x): + x = F.relu(self.bn1(self.conv1(x)), True) + x = F.avg_pool2d(self.conv2(x), 2, stride=2) + x = self.conv3(x) + x = self.conv4(x) + + previous = x + + outputs = [] + for i in range(self.num_modules): + hg = self._modules['m' + str(i)](previous) + + ll = hg + ll = self._modules['top_m_' + str(i)](ll) + + ll = F.relu(self._modules['bn_end' + str(i)] + (self._modules['conv_last' + str(i)](ll)), True) + + # Predict heatmaps + tmp_out = self._modules['l' + str(i)](ll) + outputs.append(tmp_out) + + if i < self.num_modules - 1: + ll = self._modules['bl' + str(i)](ll) + tmp_out_ = self._modules['al' + str(i)](tmp_out) + previous = previous + ll + tmp_out_ + + return outputs + + +class ResNetDepth(nn.Module): + + def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68): + self.inplanes = 64 + super(ResNetDepth, self).__init__() + self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = 
nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AvgPool2d(7) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.avgpool(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + + return x diff --git a/talkingface-toolkit-main/talkingface/utils/face_detection/utils.py b/talkingface-toolkit-main/talkingface/utils/face_detection/utils.py new file mode 100644 index 00000000..3dc4cf3e --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/face_detection/utils.py @@ -0,0 +1,313 @@ +from __future__ import print_function +import os +import sys +import time +import torch +import math +import numpy as np +import cv2 + + +def _gaussian( + size=3, sigma=0.25, amplitude=1, normalize=False, width=None, + 
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5, + mean_vert=0.5): + # handle some defaults + if width is None: + width = size + if height is None: + height = size + if sigma_horz is None: + sigma_horz = sigma + if sigma_vert is None: + sigma_vert = sigma + center_x = mean_horz * width + 0.5 + center_y = mean_vert * height + 0.5 + gauss = np.empty((height, width), dtype=np.float32) + # generate kernel + for i in range(height): + for j in range(width): + gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / ( + sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0)) + if normalize: + gauss = gauss / np.sum(gauss) + return gauss + + +def draw_gaussian(image, point, sigma): + # Check if the gaussian is inside + ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)] + br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)] + if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1): + return image + size = 6 * sigma + 1 + g = _gaussian(size) + g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))] + g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))] + img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))] + img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))] + assert (g_x[0] > 0 and g_y[1] > 0) + image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1] + ] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]] + image[image > 1] = 1 + return image + + +def transform(point, center, scale, resolution, invert=False): + """Generate and affine transformation matrix. + + Given a set of points, a center, a scale and a targer resolution, the + function generates and affine transformation matrix. If invert is ``True`` + it will produce the inverse transformation. 
+ + Arguments: + point {torch.tensor} -- the input 2D point + center {torch.tensor or numpy.array} -- the center around which to perform the transformations + scale {float} -- the scale of the face/object + resolution {float} -- the output resolution + + Keyword Arguments: + invert {bool} -- define wherever the function should produce the direct or the + inverse transformation matrix (default: {False}) + """ + _pt = torch.ones(3) + _pt[0] = point[0] + _pt[1] = point[1] + + h = 200.0 * scale + t = torch.eye(3) + t[0, 0] = resolution / h + t[1, 1] = resolution / h + t[0, 2] = resolution * (-center[0] / h + 0.5) + t[1, 2] = resolution * (-center[1] / h + 0.5) + + if invert: + t = torch.inverse(t) + + new_point = (torch.matmul(t, _pt))[0:2] + + return new_point.int() + + +def crop(image, center, scale, resolution=256.0): + """Center crops an image or set of heatmaps + + Arguments: + image {numpy.array} -- an rgb image + center {numpy.array} -- the center of the object, usually the same as of the bounding box + scale {float} -- scale of the face + + Keyword Arguments: + resolution {float} -- the size of the output cropped image (default: {256.0}) + + Returns: + [type] -- [description] + """ # Crop around the center point + """ Crops the image around the center. 
Input is expected to be an np.ndarray """ + ul = transform([1, 1], center, scale, resolution, True) + br = transform([resolution, resolution], center, scale, resolution, True) + # pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0) + if image.ndim > 2: + newDim = np.array([br[1] - ul[1], br[0] - ul[0], + image.shape[2]], dtype=np.int32) + newImg = np.zeros(newDim, dtype=np.uint8) + else: + newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32) + newImg = np.zeros(newDim, dtype=np.uint8) + ht = image.shape[0] + wd = image.shape[1] + newX = np.array( + [max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32) + newY = np.array( + [max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32) + oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32) + oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32) + newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1] + ] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :] + newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), + interpolation=cv2.INTER_LINEAR) + return newImg + + +def get_preds_fromhm(hm, center=None, scale=None): + """Obtain (x,y) coordinates given a set of N heatmaps. If the center + and the scale is provided the function will return the points also in + the original coordinate frame. 
+ + Arguments: + hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] + + Keyword Arguments: + center {torch.tensor} -- the center of the bounding box (default: {None}) + scale {float} -- face scale (default: {None}) + """ + max, idx = torch.max( + hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) + idx += 1 + preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() + preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) + preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) + + for i in range(preds.size(0)): + for j in range(preds.size(1)): + hm_ = hm[i, j, :] + pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 + if pX > 0 and pX < 63 and pY > 0 and pY < 63: + diff = torch.FloatTensor( + [hm_[pY, pX + 1] - hm_[pY, pX - 1], + hm_[pY + 1, pX] - hm_[pY - 1, pX]]) + preds[i, j].add_(diff.sign_().mul_(.25)) + + preds.add_(-.5) + + preds_orig = torch.zeros(preds.size()) + if center is not None and scale is not None: + for i in range(hm.size(0)): + for j in range(hm.size(1)): + preds_orig[i, j] = transform( + preds[i, j], center, scale, hm.size(2), True) + + return preds, preds_orig + +def get_preds_fromhm_batch(hm, centers=None, scales=None): + """Obtain (x,y) coordinates given a set of N heatmaps. If the centers + and the scales is provided the function will return the points also in + the original coordinate frame. 
+ + Arguments: + hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H] + + Keyword Arguments: + centers {torch.tensor} -- the centers of the bounding box (default: {None}) + scales {float} -- face scales (default: {None}) + """ + max, idx = torch.max( + hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2) + idx += 1 + preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float() + preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1) + preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1) + + for i in range(preds.size(0)): + for j in range(preds.size(1)): + hm_ = hm[i, j, :] + pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1 + if pX > 0 and pX < 63 and pY > 0 and pY < 63: + diff = torch.FloatTensor( + [hm_[pY, pX + 1] - hm_[pY, pX - 1], + hm_[pY + 1, pX] - hm_[pY - 1, pX]]) + preds[i, j].add_(diff.sign_().mul_(.25)) + + preds.add_(-.5) + + preds_orig = torch.zeros(preds.size()) + if centers is not None and scales is not None: + for i in range(hm.size(0)): + for j in range(hm.size(1)): + preds_orig[i, j] = transform( + preds[i, j], centers[i], scales[i], hm.size(2), True) + + return preds, preds_orig + +def shuffle_lr(parts, pairs=None): + """Shuffle the points left-right according to the axis of symmetry + of the object. + + Arguments: + parts {torch.tensor} -- a 3D or 4D object containing the + heatmaps. + + Keyword Arguments: + pairs {list of integers} -- [order of the flipped points] (default: {None}) + """ + if pairs is None: + pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, + 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35, + 34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41, + 40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63, + 62, 61, 60, 67, 66, 65] + if parts.ndimension() == 3: + parts = parts[pairs, ...] + else: + parts = parts[:, pairs, ...] 
+ + return parts + + +def flip(tensor, is_label=False): + """Flip an image or a set of heatmaps left-right + + Arguments: + tensor {numpy.array or torch.tensor} -- [the input image or heatmaps] + + Keyword Arguments: + is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False}) + """ + if not torch.is_tensor(tensor): + tensor = torch.from_numpy(tensor) + + if is_label: + tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1) + else: + tensor = tensor.flip(tensor.ndimension() - 1) + + return tensor + +# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py) + + +def appdata_dir(appname=None, roaming=False): + """ appdata_dir(appname=None, roaming=False) + + Get the path to the application directory, where applications are allowed + to write user specific files (e.g. configurations). For non-user specific + data, consider using common_appdata_dir(). + If appname is given, a subdir is appended (and created if necessary). + If roaming is True, will prefer a roaming directory (Windows Vista/7). 
+ """ + + # Define default user directory + userDir = os.getenv('FACEALIGNMENT_USERDIR', None) + if userDir is None: + userDir = os.path.expanduser('~') + if not os.path.isdir(userDir): # pragma: no cover + userDir = '/var/tmp' # issue #54 + + # Get system app data dir + path = None + if sys.platform.startswith('win'): + path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA') + path = (path2 or path1) if roaming else (path1 or path2) + elif sys.platform.startswith('darwin'): + path = os.path.join(userDir, 'Library', 'Application Support') + # On Linux and as fallback + if not (path and os.path.isdir(path)): + path = userDir + + # Maybe we should store things local to the executable (in case of a + # portable distro or a frozen application that wants to be portable) + prefix = sys.prefix + if getattr(sys, 'frozen', None): + prefix = os.path.abspath(os.path.dirname(sys.executable)) + for reldir in ('settings', '../settings'): + localpath = os.path.abspath(os.path.join(prefix, reldir)) + if os.path.isdir(localpath): # pragma: no cover + try: + open(os.path.join(localpath, 'test.write'), 'wb').close() + os.remove(os.path.join(localpath, 'test.write')) + except IOError: + pass # We cannot write in this directory + else: + path = localpath + break + + # Get path specific for this app + if appname: + if path == userDir: + appname = '.' 
+ appname.lstrip('.') # Make it a hidden directory + path = os.path.join(path, appname) + if not os.path.isdir(path): # pragma: no cover + os.mkdir(path) + + # Done + return path diff --git a/talkingface-toolkit-main/talkingface/utils/logger.py b/talkingface-toolkit-main/talkingface/utils/logger.py new file mode 100644 index 00000000..855dbb94 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/logger.py @@ -0,0 +1,95 @@ +import logging +import os +import colorlog +import re +import hashlib +from talkingface.utils.utils import get_local_time, ensure_dir +from colorama import init + +log_colors_config = { + "DEBUG": "cyan", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red", +} + +class RemoveColorFilter(logging.Filter): + def filter(self, record): + if record: + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + record.msg = ansi_escape.sub("", str(record.msg)) + return True + +def set_color(log, color, highlight=True): + color_set = ["black", "red", "green", "yellow", "blue", "pink", "cyan", "white"] + try: + index = color_set.index(color) + except ValueError: + index = len(color_set) - 1 + prev_log = "\033[" + if highlight: + prev_log += "1;3" + else: + prev_log += "0;3" + prev_log += str(index) + "m" + return prev_log + log + "\033[0m" + +def init_logger(config): + """ + A logger that can show a message on standard output and write it into the + file named `filename` simultaneously. + All the message that you want to log MUST be str. + + Args: + config (Config): An instance object of Config, used to record parameter information. 
+ + Example: + >>> logger = logging.getLogger(config) + >>> logger.debug(train_state) + >>> logger.info(train_result) + """ + init(autoreset=True) + LOGROOT = "./log/" + dir_name = os.path.dirname(LOGROOT) + ensure_dir(dir_name) + model_name = os.path.join(dir_name, config["model"]) + ensure_dir(model_name) + config_str = "".join([str(key) for key in config.final_config_dict.values()]) + md5 = hashlib.md5(config_str.encode(encoding="utf-8")).hexdigest()[:6] + logfilename = "{}/{}-{}-{}-{}.log".format( + config["model"], config["model"], config["dataset"], get_local_time(), md5 + ) + + logfilepath = os.path.join(LOGROOT, logfilename) + + filefmt = "%(asctime)-15s %(levelname)s %(message)s" + filedatefmt = "%a %d %b %Y %H:%M:%S" + fileformatter = logging.Formatter(filefmt, filedatefmt) + + sfmt = "%(log_color)s%(asctime)-15s %(levelname)s %(message)s" + sdatefmt = "%d %b %H:%M" + sformatter = colorlog.ColoredFormatter(sfmt, sdatefmt, log_colors=log_colors_config) + if config["state"] is None or config["state"].lower() == "info": + level = logging.INFO + elif config["state"].lower() == "debug": + level = logging.DEBUG + elif config["state"].lower() == "error": + level = logging.ERROR + elif config["state"].lower() == "warning": + level = logging.WARNING + elif config["state"].lower() == "critical": + level = logging.CRITICAL + else: + level = logging.INFO + + fh = logging.FileHandler(logfilepath) + fh.setLevel(level) + fh.setFormatter(fileformatter) + remove_color_filter = RemoveColorFilter() + fh.addFilter(remove_color_filter) + + sh = logging.StreamHandler() + sh.setLevel(level) + sh.setFormatter(sformatter) + + logging.basicConfig(level=level, handlers=[sh, fh]) \ No newline at end of file diff --git a/talkingface-toolkit-main/talkingface/utils/utils.py b/talkingface-toolkit-main/talkingface/utils/utils.py new file mode 100644 index 00000000..a5019491 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/utils.py @@ -0,0 +1,455 @@ +import datetime +import 
importlib +import os +import random +import pandas as pd + +import numpy as np +import torch +import torch.nn as nn + +from torch.utils.tensorboard import SummaryWriter +from texttable import Texttable + +def get_local_time(): + r"""Get current time + + Returns: + str: current time + """ + cur = datetime.datetime.now() + cur = cur.strftime("%b-%d-%Y_%H-%M-%S") + + return cur + + +def ensure_dir(dir_path): + r"""Make sure the directory exists, if it does not exist, create it + + Args: + dir_path (str): directory path + + """ + if not os.path.exists(dir_path): + os.makedirs(dir_path) + +def get_model(model_name): + r"""Automatically select model class based on model name + + Args: + model_name (str): model name + + Returns: + Recommender: model class + """ + model_submodule = [ + "audio_driven_talkingface", + "image_driven_talkingface", + "nerf_based_talkingface", + "text_to_speech", + "voice_conversion" + + ] + + model_file_name = model_name.lower() + model_module = None + for submodule in model_submodule: + module_path = ".".join(["talkingface.model", submodule, model_file_name]) + if importlib.util.find_spec(module_path, __name__): + model_module = importlib.import_module(module_path, __name__) + break + + if model_module is None: + raise ValueError( + "`model_name` [{}] is not the name of an existing model.".format(model_name) + ) + model_class = getattr(model_module, model_name) + return model_class + +def get_trainer(model_name): + r"""Automatically select trainer class based on model name + + Args: + model_name (str): model name + + Returns: + Trainer: trainer class + """ + try: + return getattr( + importlib.import_module("talkingface.trainer"), model_name + "Trainer" + ) + except AttributeError: + raise AttributeError( + "There is no trainer named `{}`".format(model_name) + ) + +def early_stopping(value, best, cur_step, max_step, bigger=False): + r"""validation-based early stopping + + Args: + value (float): current result + best (float): best result + 
cur_step (int): the number of consecutive steps that did not exceed the best result + max_step (int): threshold steps for stopping + bigger (bool, optional): whether the bigger the better + + Returns: + tuple: + - float, + best result after this step + - int, + the number of consecutive steps that did not exceed the best result after this step + - bool, + whether to stop + - bool, + whether to update + """ + stop_flag = False + update_flag = False + if bigger: + if value >= best: + cur_step = 0 + best = value + update_flag = True + else: + cur_step += 1 + if cur_step > max_step: + stop_flag = True + else: + if value <= best: + cur_step = 0 + best = value + update_flag = True + else: + cur_step += 1 + if cur_step > max_step: + stop_flag = True + return best, cur_step, stop_flag, update_flag + + +def calculate_valid_score(valid_result, valid_metric=None): + r"""Calculate the valid score according to the valid result and valid metric + + Args: + valid_result (dict): valid result + valid_metric (Metric or None, optional): valid metric + + Returns: + float: valid score + """ + if valid_metric is None: + return valid_result + else: + return valid_result[valid_metric] + +def dict2str(result_dict): + r"""convert result dict to str + + Args: + result_dict (dict): result dict + + Returns: + str: result str + """ + + return " ".join( + [str(metric) + " : " + str(value) for metric, value in result_dict.items()] + ) + + +def init_seed(seed, reproducibility): + r"""init random seed for random functions in numpy, torch, cuda and cudnn + + Args: + seed (int): random seed + reproducibility (bool): Whether to require reproducibility + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if reproducibility: + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + else: + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = False + +def 
get_tensorboard(logger): + r"""Creates a SummaryWriter of Tensorboard that can log PyTorch models and metrics into a directory for + visualization within the TensorBoard UI. + For the convenience of the user, the naming rule of the SummaryWriter's log_dir is the same as the logger. + + Args: + logger: its output filename is used to name the SummaryWriter's log_dir. + If the filename is not available, we will name the log_dir according to the current time. + + Returns: + SummaryWriter: it will write out events and summaries to the event file. + """ + base_path = "log_tensorboard" + + dir_name = None + for handler in logger.handlers: + if hasattr(handler, "baseFilename"): + dir_name = os.path.basename(getattr(handler, "baseFilename")).split(".")[0] + break + if dir_name is None: + dir_name = "{}-{}".format("model", get_local_time()) + + dir_path = os.path.join(base_path, dir_name) + writer = SummaryWriter(dir_path) + return writer + +def get_gpu_usage(device=None): + r"""Return the reserved memory and total memory of given device in a string. + Args: + device: cuda.device. It is the device that the model run on. + + Returns: + str: it contains the info about reserved memory and total memory of given device. + """ + + reserved = torch.cuda.max_memory_reserved(device) / 1024**3 + total = torch.cuda.get_device_properties(device).total_memory / 1024**3 + + return "{:.2f} G/{:.2f} G".format(reserved, total) + +def get_flops(model, dataset, device, logger, transform, verbose=False): + r"""Given a model and dataset to the model, compute the per-operator flops + of the given model. + Args: + model: the model to compute flop counts. + dataset: dataset that are passed to `model` to count flops. + device: cuda.device. It is the device that the model run on. + verbose: whether to print information of modules. + + Returns: + total_ops: the number of flops for each operation. 
+ """ + # if model.type == ModelType.DECISIONTREE: + # return 1 + # if model.__class__.__name__ == "Pop": + # return 1 + + import copy + + model = copy.deepcopy(model) + + def count_normalization(m, x, y): + x = x[0] + flops = torch.DoubleTensor([2 * x.numel()]) + m.total_ops += flops + + def count_embedding(m, x, y): + x = x[0] + nelements = x.numel() + hiddensize = y.shape[-1] + m.total_ops += nelements * hiddensize + + class TracingAdapter(torch.nn.Module): + def __init__(self, rec_model): + super().__init__() + self.model = rec_model + + def forward(self, interaction): + return self.model.predict(interaction) + + custom_ops = { + torch.nn.Embedding: count_embedding, + torch.nn.LayerNorm: count_normalization, + } + wrapper = TracingAdapter(model) + inter = dataset[torch.tensor([1])].to(device) + inter = transform(dataset, inter) + inputs = (inter,) + from thop.profile import register_hooks + from thop.vision.basic_hooks import count_parameters + + handler_collection = {} + fn_handles = [] + params_handles = [] + types_collection = set() + if custom_ops is None: + custom_ops = {} + + def add_hooks(m: nn.Module): + m.register_buffer("total_ops", torch.zeros(1, dtype=torch.float64)) + m.register_buffer("total_params", torch.zeros(1, dtype=torch.float64)) + + m_type = type(m) + + fn = None + if m_type in custom_ops: + fn = custom_ops[m_type] + if m_type not in types_collection and verbose: + logger.info("Customize rule %s() %s." % (fn.__qualname__, m_type)) + elif m_type in register_hooks: + fn = register_hooks[m_type] + if m_type not in types_collection and verbose: + logger.info("Register %s() for %s." % (fn.__qualname__, m_type)) + else: + if m_type not in types_collection and verbose: + logger.warning( + "[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." 
+ % m_type + ) + + if fn is not None: + handle_fn = m.register_forward_hook(fn) + handle_paras = m.register_forward_hook(count_parameters) + handler_collection[m] = ( + handle_fn, + handle_paras, + ) + fn_handles.append(handle_fn) + params_handles.append(handle_paras) + types_collection.add(m_type) + + prev_training_status = wrapper.training + + wrapper.eval() + wrapper.apply(add_hooks) + + with torch.no_grad(): + wrapper(*inputs) + + def dfs_count(module: nn.Module, prefix="\t"): + total_ops, total_params = module.total_ops.item(), 0 + ret_dict = {} + for n, m in module.named_children(): + next_dict = {} + if m in handler_collection and not isinstance( + m, (nn.Sequential, nn.ModuleList) + ): + m_ops, m_params = m.total_ops.item(), m.total_params.item() + else: + m_ops, m_params, next_dict = dfs_count(m, prefix=prefix + "\t") + ret_dict[n] = (m_ops, m_params, next_dict) + total_ops += m_ops + total_params += m_params + + return total_ops, total_params, ret_dict + + total_ops, total_params, ret_dict = dfs_count(wrapper) + + # reset wrapper to original status + wrapper.train(prev_training_status) + for m, (op_handler, params_handler) in handler_collection.items(): + m._buffers.pop("total_ops") + m._buffers.pop("total_params") + for i in range(len(fn_handles)): + fn_handles[i].remove() + params_handles[i].remove() + + return total_ops + +def list_to_latex(convert_list, bigger_flag=True, subset_columns=[]): + result = {} + for d in convert_list: + for key, value in d.items(): + if key in result: + result[key].append(value) + else: + result[key] = [value] + + df = pd.DataFrame.from_dict(result, orient="index").T + + if len(subset_columns) == 0: + tex = df.to_latex(index=False) + return df, tex + + def bold_func(x, bigger_flag): + if bigger_flag: + return np.where(x == np.max(x.to_numpy()), "font-weight:bold", None) + else: + return np.where(x == np.min(x.to_numpy()), "font-weight:bold", None) + + style = df.style + style.apply(bold_func, bigger_flag=bigger_flag, 
subset=subset_columns) + style.format(precision=4) + + num_column = len(df.columns) + column_format = "c" * num_column + tex = style.hide(axis="index").to_latex( + caption="Result Table", + label="Result Table", + convert_css=True, + hrules=True, + column_format=column_format, + ) + + return df, tex + +def get_environment(config): + gpu_usage = ( + get_gpu_usage(config["device"]) + if torch.cuda.is_available() and config["use_gpu"] + else "0.0 / 0.0" + ) + + import psutil + + memory_used = psutil.Process(os.getpid()).memory_info().rss / 1024**3 + memory_total = psutil.virtual_memory()[0] / 1024**3 + memory_usage = "{:.2f} G/{:.2f} G".format(memory_used, memory_total) + cpu_usage = "{:.2f} %".format(psutil.cpu_percent(interval=1)) + """environment_data = [ + {"Environment": "CPU", "Usage": cpu_usage,}, + {"Environment": "GPU", "Usage": gpu_usage, }, + {"Environment": "Memory", "Usage": memory_usage, }, + ]""" + + table = Texttable() + table.set_cols_align(["l", "c"]) + table.set_cols_valign(["m", "m"]) + table.add_rows( + [ + ["Environment", "Usage"], + ["CPU", cpu_usage], + ["GPU", gpu_usage], + ["Memory", memory_usage], + ] + ) + + return table + +def get_preprocess(dataset_name): + r"""Automatically select dataset preprocess class based on dataset name + """ + model_file_name = dataset_name.lower() + module_path = "talkingface.utils.data_process" + model_module = importlib.import_module(module_path, __name__) + try: + preprocess_class = getattr(model_module, dataset_name+'Preprocess') + except: + raise ValueError( + "`dataset_name` [{}] is not the name of an existing dataset.".format(dataset_name) + ) + return preprocess_class + +def create_dataset(config): + r"""Automatically select dataset class based on dataset name + """ + model_name = config['model'] + dataset_file_name = model_name.lower()+'_dataset' + module_path = ".".join(["talkingface.data.dataset", dataset_file_name]) + if importlib.util.find_spec(module_path, __name__): + dataset_module = 
importlib.import_module(module_path, __name__) + if dataset_module is None: + raise ValueError( + "`dataset_file_name` [{}] is not the name of an existing dataset.".format(dataset_file_name) + ) + dataset_class = getattr(dataset_module, model_name+'Dataset') + + return dataset_class(config, config['train_filelist']), dataset_class(config, config['val_filelist']) + + + + + + + + diff --git a/talkingface-toolkit-main/talkingface/utils/wandblogger.py b/talkingface-toolkit-main/talkingface/utils/wandblogger.py new file mode 100644 index 00000000..8d958474 --- /dev/null +++ b/talkingface-toolkit-main/talkingface/utils/wandblogger.py @@ -0,0 +1,57 @@ +class WandbLogger(object): + """WandbLogger to log metrics to Weights and Biases.""" + + def __init__(self, config): + """ + Args: + config (dict): A dictionary of parameters used by RecBole. + """ + self.config = config + self.log_wandb = config.log_wandb + self.setup() + + def setup(self): + if self.log_wandb: + try: + import wandb + + self._wandb = wandb + except ImportError: + raise ImportError( + "To use the Weights and Biases Logger please install wandb." + "Run `pip install wandb` to install it." 
+ ) + + # Initialize a W&B run + if self._wandb.run is None: + self._wandb.init(project=self.config.wandb_project, config=self.config) + + self._set_steps() + + def log_metrics(self, metrics, head="train", commit=True): + if self.log_wandb: + if head: + metrics = self._add_head_to_metrics(metrics, head) + self._wandb.log(metrics, commit=commit) + else: + self._wandb.log(metrics, commit=commit) + + def log_eval_metrics(self, metrics, head="eval"): + if self.log_wandb: + metrics = self._add_head_to_metrics(metrics, head) + for k, v in metrics.items(): + self._wandb.run.summary[k] = v + + def _set_steps(self): + self._wandb.define_metric("train/*", step_metric="train_step") + self._wandb.define_metric("valid/*", step_metric="valid_step") + + def _add_head_to_metrics(self, metrics, head): + head_metrics = dict() + for k, v in metrics.items(): + if "_step" in k: + head_metrics[k] = v + else: + head_metrics[f"{head}/{k}"] = v + + return head_metrics