diff --git a/LiveSpeechPortraits/Paper_README.md b/LiveSpeechPortraits/Paper_README.md new file mode 100644 index 00000000..f0406a14 --- /dev/null +++ b/LiveSpeechPortraits/Paper_README.md @@ -0,0 +1,106 @@ +# Live Speech Portraits: Real-Time Photorealistic Talking-Head Animation + +This repository contains the implementation of the following paper: + +> **Live Speech Portraits: Real-Time Photorealistic Talking-Head Animation** +> +> Yuanxun Lu, [Jinxiang Chai](https://scholar.google.com/citations?user=OcN1_gwAAAAJ&hl=zh-CN&oi=ao), [Xun Cao](https://cite.nju.edu.cn/People/Faculty/20190621/i5054.html) *(SIGGRAPH Asia 2021)* +> +> **Abstract**: To the best of our knowledge, we first present a live system that generates personalized photorealistic talking-head animation only driven by audio signals at over 30 fps. Our system contains three stages. The first stage is a deep neural network that extracts deep audio features along with a manifold projection to project the features to the target person's speech space. In the second stage, we learn facial dynamics and motions from the projected audio features. The predicted motions include head poses and upper body motions, where the former is generated by an autoregressive probabilistic model which models the head pose distribution of the target person. Upper body motions are deduced from head poses. In the final stage, we generate conditional feature maps from previous predictions and send them with a candidate image set to an image-to-image translation network to synthesize photorealistic renderings. Our method generalizes well to wild audio and successfully synthesizes high-fidelity personalized facial details, e.g., wrinkles, teeth. Our method also allows explicit control of head poses. Extensive qualitative and quantitative evaluations, along with user studies, demonstrate the superiority of our method over state-of-the-art techniques. +> +> [[Project Page]](https://yuanxunlu.github.io/projects/LiveSpeechPortraits/) [[Paper]](https://yuanxunlu.github.io/projects/LiveSpeechPortraits/resources/SIGGRAPH_Asia_2021__Live_Speech_Portraits__Real_Time_Photorealistic_Talking_Head_Animation.pdf) [[Arxiv]](https://arxiv.org/abs/2109.10595) [[Web Demo]](https://replicate.ai/yuanxunlu/livespeechportraits) + +![Teaser](./doc/Teaser.jpg) + +Figure 1. Given an arbitrary input audio stream, our system generates personalized and photorealistic talking-head animation in real-time. Right: May and Obama are driven by the same utterance but present different speaking characteristics. + + + + +## Requirements + +- This project is successfully trained and tested on Windows10 with PyTorch 1.7 (Python 3.6). Linux and lower version PyTorch should also work (not tested). We recommend creating a new environment: + +``` +conda create -n LSP python=3.6 +conda activate LSP +``` + +- Clone the repository: + +``` +git clone https://github.com/YuanxunLu/LiveSpeechPortraits.git +cd LiveSpeechPortraits +``` + +- FFmpeg is required to combine the audio and the silent generated videos. Please check [FFmpeg](http://ffmpeg.org/download.html) for installation. For Linux users, you can also: + +``` +sudo apt-get install ffmpeg +``` + +- Install the dependences: + +``` +pip install -r requirements.txt +``` + + + +## Demo + +- Download the pre-trained models and data from [Google Drive](https://drive.google.com/drive/folders/1sHc2xEEGwnb0h2rkUhG9sPmOxvRvPVpJ?usp=sharing) to the `data` folder. Five subjects data are released (May, Obama1, Obama2, Nadella and McStay). + +- Run the demo: + + ``` + python demo.py --id May --driving_audio ./data/Input/00083.wav --device cuda + ``` + + Results can be found under the `results` folder. + +- **(New!) Docker and Web Demo** + + We are really grateful to [Andreas](https://github.com/andreasjansson) from [Replicate](https://replicate.ai/home) for his amazing job to make the web demo! Now you can run the [Demo](https://replicate.ai/yuanxunlu/livespeechportraits) on the browser. + +- **For the orginal links of these videos, please check issue [#7](https://github.com/YuanxunLu/LiveSpeechPortraits/issues/7).** + + + + +## Citation + +If you find this project useful for your research, please consider citing: + +``` +@article{lu2021live, + author = {Lu, Yuanxun and Chai, Jinxiang and Cao, Xun}, + title = {{Live Speech Portraits}: Real-Time Photorealistic Talking-Head Animation}, + journal = {ACM Transactions on Graphics}, + numpages = {17}, + volume={40}, + number={6}, + month = December, + year = {2021}, + doi={10.1145/3478513.3480484} +} +``` + + + +## Acknowledgment + +- This repo was built based on the framework of [pix2pix-pytorch](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). +- Thanks the authors of [MakeItTalk](https://github.com/adobe-research/MakeItTalk), [ATVG](https://github.com/lelechen63/ATVGnet), [RhythmicHead](https://github.com/lelechen63/Talking-head-Generation-with-Rhythmic-Head-Motion), [Speech-Driven Animation](https://github.com/DinoMan/speech-driven-animation) for making their excellent work and codes publicly available. +- Thanks [Andreas](https://github.com/andreasjansson) for the efforts of the web demo. + + + + + + + + + + + diff --git a/LiveSpeechPortraits/README.md b/LiveSpeechPortraits/README.md new file mode 100644 index 00000000..f708cc39 --- /dev/null +++ b/LiveSpeechPortraits/README.md @@ -0,0 +1,143 @@ +# 语音识别大作业 Live Speech Portraits + +## 小组成员 + +孙乐天,余博轩,孙仲天,李泽鸣 + +## 文件说明 + +`./source`:目录下包含全部源代码 + +`./data`:目录下包含 **部分** 测试用的音视频以及预训练模型 + +`README.md`:Docker 镜像说明文档 + +`Paper_README.md`:论文代码的说明文档 + +## LiveSpeechPortraits Docker 镜像使用说明 + +本项目的模型使用 `Docker 27.3` 版本封装了模型,使用前请确保已正确安装 `Docker Engine 27.3` 。您可以通过访问 [Docker 官网](https://www.docker.com/) 下载并安装 Docker。 + +论文配套代码使用 `pytorch-1.7.10+cu110` 环境,我们基于此版本进行了复现,Docker镜像可从此处下载:[lsp_demo_1.3.tar](https://pan.baidu.com/s/1NIJDdSzwFL3lPSb-tYuaZQ?pwd=fnbg) 。 + +由于使用的 `cuda` 版本较老,最新的RTX40系显卡无法进行训练。因此,面向高版本重新封装了基于 `cuda 11.8` 的Docker镜像。由于原文代码存在兼容性问题,因此我们替换了不兼容的模块,并改写了部分代码。该Docker镜像可从此处下载:[lsp_quickrun_cu118.tar](https://pan.baidu.com/s/1a1C2oy5DBqbVjWnXOOr9rw?pwd=imjb) 。 + +### 文件准备 + +#### 1. 从.tar文件载入 Docker 镜像 + +首先,参考上方链接,下载与您的 GPU 兼容的 Docker 镜像文件。 + +使用以下命令载入 Docker 镜像: + +```bash +docker load -i lsp_demo_XXXXX.tar +``` + +#### 2. 准备需要使用的数据 + +在运行 Docker 镜像前,请创建以下三个文件夹,分别用于存放预训练模型、输入音频和输出结果: + +`models`:存放预训练模型 + +`input`:存放输入音频文件 + +`results`:存放生成的输出结果 + +您可以通过 Google Drive [下载](https://drive.google.com/drive/folders/1sHc2xEEGwnb0h2rkUhG9sPmOxvRvPVpJ?usp=sharing) 预训练模型,并将其保存至 `models` 文件夹。确保文件夹中的内容如下所示: + +``` +. +|-- APC_epoch_160.model +|-- May +|-- McStay +|-- Nadella +|-- Obama1 +`-- Obama2 +``` + +将您需要输入模型的数据文件保存在 `input` 文件夹中 + +### 镜像使用参数说明 + +`LSP_QuickRun` 镜像支持两种运行模式:用于 **生成视频** 的 `--lspmodel` 模式和用于 **评估模型** 的 `--eval` 模式。 + +#### 1. `--lspmodel` 模式 + +在该模式下,您需要指定以下参数: + +`--id` : 预训练模型的名称,例如 `May` 、 `Obama1` 、 `Obama2` 等; + +`--device` : 所使用的设备类型,例如 `cuda` 、 `cpu` 等; + +`--driving_audio` : 输入音频文件的路径(Docker 容器内的路径)。 + +生成的视频文件将保存到容器内的 `/workspace/results` 目录中。 + +#### 2. `--eval` 模式 + +在此模式下,您需要指定以下参数: + +`--gt_video` : 参考视频的路径(Docker 容器内的路径); + +`--gen_video` : 模型生成的视频路径(Docker 容器内的路径)。 + +评估结果将在命令行中显示。 + +### Docker 运行命令示例 + +#### 1. `--lspmodel` 模式 + +运行命令的模板如下所示: + +```dockerfile +docker run -it --gpus all --rm --shm-size=8g \ +-v <本地的model文件夹目录>:/workspace/data \ +-v <本地的input文件夹目录>:/workspace/input \ +-v <本地的result文件夹目录>:/workspace/results \ +<镜像名称> \ +--lspmodel \ +--id <预训练模型ID> \ +--device <设备名称> \ +--driving_audio <容器内输入音频的路径> +``` + +例如: + +```dockerfile +docker run -it --gpus all --rm --shm-size=8g \ +-v E:\Code\Docker\models:/workspace/data \ +-v E:\Code\Docker\input:/workspace/input \ +-v E:\Code\Docker\results:/workspace/results \ +lsp_quickrun:1.3 \ +--lspmodel \ +--id Obama1 \ +--device cuda \ +--driving_audio /workspace/input/00083.wav +``` + +#### 2. `--eval` 模式 + +运行命令的模板如下所示: + +```dockerfile +docker run -it --gpus all --rm --shm-size=8g \ +-v <本地的input文件夹目录>:/workspace/input \ +-v <本地的result文件夹目录>:/workspace/results \ +<镜像名称> \ +--eval \ +--gt_video <容器内参考视频的路径> \ +--gen_video <容器内生成视频的路径> +``` + +例如: + +```dockerfile +docker run -it --gpus all --rm --shm-size=8g \ +-v E:\Code\Docker\input:/workspace/input \ +-v E:\Code\Docker\results:/workspace/results \ +lsp_quickrun:1.3 \ +--eval \ +--gt_video /workspace/input/May_short.mp4 \ +--gen_video /workspace/results/May/May_short/May_short.avi +``` \ No newline at end of file diff --git a/LiveSpeechPortraits/data/input/00083.wav b/LiveSpeechPortraits/data/input/00083.wav new file mode 100644 index 00000000..2bb527ff Binary files /dev/null and b/LiveSpeechPortraits/data/input/00083.wav differ diff --git a/LiveSpeechPortraits/data/input/May_short.mp4 b/LiveSpeechPortraits/data/input/May_short.mp4 new file mode 100644 index 00000000..247b6dd8 Binary files /dev/null and b/LiveSpeechPortraits/data/input/May_short.mp4 differ diff --git a/LiveSpeechPortraits/data/models/APC_epoch_160.model b/LiveSpeechPortraits/data/models/APC_epoch_160.model new file mode 100644 index 00000000..6572ec1d Binary files /dev/null and b/LiveSpeechPortraits/data/models/APC_epoch_160.model differ diff --git a/LiveSpeechPortraits/source_code/LICENSE b/LiveSpeechPortraits/source_code/LICENSE new file mode 100644 index 00000000..fd6831ed --- /dev/null +++ b/LiveSpeechPortraits/source_code/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 OldSix + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/LiveSpeechPortraits/source_code/cog.yaml b/LiveSpeechPortraits/source_code/cog.yaml new file mode 100644 index 00000000..04b96b32 --- /dev/null +++ b/LiveSpeechPortraits/source_code/cog.yaml @@ -0,0 +1,30 @@ +build: + gpu: true + python_version: "3.8" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + - "libsox-fmt-mp3" + python_packages: + - "torch==1.7.1" + - "torchvision==0.8.2" + - "numpy==1.18.1" + - "ipython==7.21.0" + - "Pillow==8.3.1" + - "scikit-image==0.18.3" + - "librosa==0.7.2" + - "tqdm==4.62.3" + - "scipy==1.7.1" + - "dominate==2.6.0" + - "albumentations==0.5.2" + - "beautifulsoup4==4.10.0" + - "sox==1.4.1" + - "h5py==3.4.0" + - "numba==0.48" + - "moviepy==1.0.3" + run: + - apt update -y && apt-get install ffmpeg -y + - apt-get install sox libsox-fmt-mp3 -y + - pip install opencv-python==4.1.2.30 + +predict: "predict.py:Predictor" diff --git a/LiveSpeechPortraits/source_code/config/May.yaml b/LiveSpeechPortraits/source_code/config/May.yaml new file mode 100644 index 00000000..d28c4e76 --- /dev/null +++ b/LiveSpeechPortraits/source_code/config/May.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/May/checkpoints/Audio2Feature.pkl' + smooth: 1.5 + AMP: ['XYZ', 2, 2, 2] # method, x, y, z + Headpose: + ckp_path: './data/May/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [1, 0.5] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/May/checkpoints/Feature2Face.pkl' + size: 'large' + save_input: 1 + + +dataset_params: + root: './data/May/' + fit_data_path: './data/May/3d_fit_data.npz' + pts3d_path: './data/May/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/LiveSpeechPortraits/source_code/config/McStay.yaml b/LiveSpeechPortraits/source_code/config/McStay.yaml new file mode 100644 index 00000000..25d9db17 --- /dev/null +++ b/LiveSpeechPortraits/source_code/config/McStay.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/McStay/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/McStay/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/McStay/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/McStay/' + fit_data_path: './data/McStay/3d_fit_data.npz' + pts3d_path: './data/McStay/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/LiveSpeechPortraits/source_code/config/Nadella.yaml b/LiveSpeechPortraits/source_code/config/Nadella.yaml new file mode 100644 index 00000000..66f33573 --- /dev/null +++ b/LiveSpeechPortraits/source_code/config/Nadella.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/Nadella/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/Nadella/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [0.5, 0.5] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/Nadella/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/Nadella/' + fit_data_path: './data/Nadella/3d_fit_data.npz' + pts3d_path: './data/Nadella/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/LiveSpeechPortraits/source_code/config/Obama1.yaml b/LiveSpeechPortraits/source_code/config/Obama1.yaml new file mode 100644 index 00000000..ce414876 --- /dev/null +++ b/LiveSpeechPortraits/source_code/config/Obama1.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/Obama1/checkpoints/Audio2Feature.pkl' + smooth: 1 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/Obama1/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [2, 8] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/Obama1/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/Obama1/' + fit_data_path: './data/Obama1/3d_fit_data.npz' + pts3d_path: './data/Obama1/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/LiveSpeechPortraits/source_code/config/Obama2.yaml b/LiveSpeechPortraits/source_code/config/Obama2.yaml new file mode 100644 index 00000000..6d543151 --- /dev/null +++ b/LiveSpeechPortraits/source_code/config/Obama2.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/Obama2/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/Obama2/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [3, 10] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/Obama2/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/Obama2/' + fit_data_path: './data/Obama2/3d_fit_data.npz' + pts3d_path: './data/Obama2/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/LiveSpeechPortraits/source_code/datasets/__init__.py b/LiveSpeechPortraits/source_code/datasets/__init__.py new file mode 100644 index 00000000..dbe07be3 --- /dev/null +++ b/LiveSpeechPortraits/source_code/datasets/__init__.py @@ -0,0 +1,93 @@ +"""This package includes all the modules related to data loading and preprocessing + + To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. + You need to implement four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point from data loader. + -- : (optionally) add dataset-specific options and set default options. + +Now you can use the dataset class by specifying flag '--dataset_mode dummy'. +See our template dataset class 'template_dataset.py' for more details. +""" +import importlib +import torch.utils.data +from datasets.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name): + """Import the module "data/[dataset_name]_dataset.py". + + In the file, the class called DatasetNameDataset() will + be instantiated. It has to be a subclass of BaseDataset, + and it is case-insensitive. + """ + dataset_filename = "datasets." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + + return dataset + + +def get_option_setter(dataset_name): + """Return the static method of the dataset class.""" + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt): + """Create a dataset given the option. + + This function wraps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from data import create_dataset + >>> dataset = create_dataset(opt) + """ + data_loader = CustomDatasetDataLoader(opt) + dataset = data_loader.load_data() + return dataset + + +class CustomDatasetDataLoader(): + """Wrapper class of Dataset class that performs multi-threaded data loading""" + + def __init__(self, opt): + """Initialize this class + + Step 1: create a dataset instance given the name [dataset_mode] + Step 2: create a multi-threaded data loader. + """ + self.opt = opt + dataset_class = find_dataset_using_name(opt.dataset_mode) + self.dataset = dataset_class(opt) + print("dataset [%s] was created" % type(self.dataset).__name__) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batch_size, + shuffle=not opt.serial_batches, + num_workers=int(opt.num_threads)) + + def load_data(self): + return self + + def __len__(self): + """Return the number of data in the dataset""" + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + """Return a batch of data""" + for i, data in enumerate(self.dataloader): + if i * self.opt.batch_size >= self.opt.max_dataset_size: + break + yield data diff --git a/LiveSpeechPortraits/source_code/datasets/audiovisual_dataset.py b/LiveSpeechPortraits/source_code/datasets/audiovisual_dataset.py new file mode 100644 index 00000000..5db06c92 --- /dev/null +++ b/LiveSpeechPortraits/source_code/datasets/audiovisual_dataset.py @@ -0,0 +1,301 @@ +import sys +sys.path.append("..") + +from datasets.base_dataset import BaseDataset +import scipy.io as sio +import torch +import librosa +import bisect +import os +import numpy as np +from models.networks import APC_encoder + +from funcs import utils + + + +class AudioVisualDataset(BaseDataset): + """ audio-visual dataset. currently, return 2D info and 3D tracking info. + + # for wavenet: + # |----receptive_field----| + # |--output_length--| + # example: | | | | | | | | | | | | | | | | | | | | | + # target: | | | | | | | | | | + + """ + def __init__(self, opt): + # save the option and dataset root + BaseDataset.__init__(self, opt) + self.isTrain = self.opt.isTrain + self.state = opt.dataset_type + self.dataset_name = opt.dataset_names + self.target_length = opt.time_frame_length + self.sample_rate = opt.sample_rate + self.fps = opt.FPS + + self.audioRF_history = opt.audioRF_history + self.audioRF_future = opt.audioRF_future + self.compute_mel_online = opt.compute_mel_online + self.feature_name = opt.feature_name + + self.audio_samples_one_frame = self.sample_rate / self.fps + self.frame_jump_stride = opt.frame_jump_stride + self.augment = False + self.task = opt.task + self.item_length_audio = int((self.audioRF_history + self.audioRF_future)/ self.fps * self.sample_rate) + + if self.task == 'Audio2Feature': + if opt.feature_decoder == 'WaveNet': + self.A2L_receptive_field = opt.A2L_receptive_field + self.A2L_item_length = self.A2L_receptive_field + self.target_length - 1 + elif opt.feature_decoder == 'LSTM': + self.A2L_receptive_field = 30 + self.A2L_item_length = self.A2L_receptive_field + self.target_length - 1 + elif self.task == 'Audio2Headpose': + self.A2H_receptive_field = opt.A2H_receptive_field + self.A2H_item_length = self.A2H_receptive_field + self.target_length - 1 + self.audio_window = opt.audio_windows + self.half_audio_win = int(self.audio_window / 2) + + self.frame_future = opt.frame_future + self.predict_length = opt.predict_length + self.predict_len = int((self.predict_length - 1) / 2) + + self.gpu_ids = opt.gpu_ids + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + print('self.device:', self.device) + if self.task == 'Audio2Feature': + self.seq_len = opt.sequence_length + + self.total_len = 0 + self.dataset_root = os.path.join(self.root, self.dataset_name) + if self.state == 'Train': + self.clip_names = opt.train_dataset_names + elif self.state == 'Val': + self.clip_names = opt.validate_dataset_names + elif self.state == 'Test': + self.clip_names = opt.test_dataset_names + + + self.clip_nums = len(self.clip_names) + # main info + self.audio = [''] * self.clip_nums + self.audio_features = [''] * self.clip_nums + self.feats = [''] * self.clip_nums + self.exps = [''] * self.clip_nums + self.pts3d = [''] * self.clip_nums + self.rot_angles = [''] * self.clip_nums + self.trans = [''] * self.clip_nums + self.headposes = [''] * self.clip_nums + self.velocity_pose = [''] * self.clip_nums + self.acceleration_pose = [''] * self.clip_nums + self.mean_trans = [''] * self.clip_nums + if self.state == 'Test': + self.landmarks = [''] * self.clip_nums + # meta info + self.start_point = [''] * self.clip_nums + self.end_point = [''] * self.clip_nums + self.len = [''] * self.clip_nums + self.sample_start = [] + self.clip_valid = ['True'] * self.clip_nums + self.invalid_clip = [] + + + self.mouth_related_indices = np.concatenate([np.arange(4, 11), np.arange(46, 64)]) + if self.task == 'Audio2Feature': + if self.opt.only_mouth: + self.indices = self.mouth_related_indices + else: + self.indices = np.arange(73) + if opt.use_delta_pts: + self.pts3d_mean = np.load(os.path.join(self.dataset_root, 'mean_pts3d.npy')) + + for i in range(self.clip_nums): + name = self.clip_names[i] + clip_root = os.path.join(self.dataset_root, name) + # audio + if os.path.exists(os.path.join(clip_root, name + '_denoise.wav')): + audio_path = os.path.join(clip_root, name + '_denoise.wav') + print('find denoised wav!') + else: + audio_path = os.path.join(clip_root, name + '.wav') + self.audio[i], _ = librosa.load(audio_path, sr=self.sample_rate) + + if self.opt.audio_encoder == 'APC': + APC_name = os.path.split(self.opt.APC_model_path)[-1] + APC_feature_file = name + '_APC_feature_V0324_ckpt_{}.npy'.format(APC_name) + APC_feature_path = os.path.join(clip_root, APC_feature_file) + need_deepfeats = False if os.path.exists(APC_feature_path) else True + if not need_deepfeats: + self.audio_features[i] = np.load(APC_feature_path).astype(np.float32) + else: + need_deepfeats = False + + + # 3D landmarks & headposes + if self.task == 'Audio2Feature': + self.start_point[i] = 0 + elif self.task == 'Audio2Headpose': + self.start_point[i] = 300 + fit_data_path = os.path.join(clip_root, '3d_fit_data.npz') + fit_data = np.load(fit_data_path) + if not opt.ispts_norm: + ori_pts3d = fit_data['pts_3d'].astype(np.float32) + else: + ori_pts3d = np.load(os.path.join(clip_root, 'tracked3D_normalized_pts_fix_contour.npy')) + if opt.use_delta_pts: + self.pts3d[i] = ori_pts3d - self.pts3d_mean + else: + self.pts3d[i] = ori_pts3d + if opt.feature_dtype == 'pts3d': + self.feats[i] = self.pts3d[i] + elif opt.feature_dtype == 'FW': + track_data_path = os.path.join(clip_root, 'tracking_results.mat') + self.feats[i] = sio.loadmat(track_data_path)['exps'].astype(np.float32) + self.rot_angles[i] = fit_data['rot_angles'].astype(np.float32) + # change -180~180 to 0~360 + if not self.dataset_name == 'Yuxuan': + rot_change = self.rot_angles[i][:, 0] < 0 + self.rot_angles[i][rot_change, 0] += 360 + self.rot_angles[i][:,0] -= 180 # change x axis direction + # use delta translation + self.mean_trans[i] = fit_data['trans'][:,:,0].astype(np.float32).mean(axis=0) + self.trans[i] = fit_data['trans'][:,:,0].astype(np.float32) - self.mean_trans[i] + + self.headposes[i] = np.concatenate([self.rot_angles[i], self.trans[i]], axis=1) + self.velocity_pose[i] = np.concatenate([np.zeros(6)[None,:], self.headposes[i][1:] - self.headposes[i][:-1]]) + self.acceleration_pose[i] = np.concatenate([np.zeros(6)[None,:], self.velocity_pose[i][1:] - self.velocity_pose[i][:-1]]) + + if self.dataset_name == 'Yuxuan': + total_frames = self.feats[i].shape[0] - 300 - 130 + else: + total_frames = self.feats[i].shape[0] - 60 + + + if need_deepfeats: + if self.opt.audio_encoder == 'APC': + print('dataset {} need to pre-compute APC features ...'.format(name)) + print('first we compute mel spectram for dataset {} '.format(name)) + mel80 = utils.compute_mel_one_sequence(self.audio[i]) + mel_nframe = mel80.shape[0] + print('loading pre-trained model: ', self.opt.APC_model_path) + APC_model = APC_encoder(self.opt.audiofeature_input_channels, + self.opt.APC_hidden_size, + self.opt.APC_rnn_layers, + self.opt.APC_residual) + APC_model.load_state_dict(torch.load(self.opt.APC_model_path, map_location=str(self.device)), strict=False) +# APC_model.load_state_dict(torch.load(self.opt.APC_model_path), strict=False) + APC_model.cuda() + APC_model.eval() + with torch.no_grad(): + length = torch.Tensor([mel_nframe]) +# hidden_reps = torch.zeros([mel_nframe, self.opt.APC_hidden_size]).cuda() + mel80_torch = torch.from_numpy(mel80.astype(np.float32)).cuda().unsqueeze(0) + hidden_reps = APC_model.forward(mel80_torch, length)[0] # [mel_nframe, 512] + hidden_reps = hidden_reps.cpu().numpy() + np.save(APC_feature_path, hidden_reps) + self.audio_features[i] = hidden_reps + + + valid_frames = total_frames - self.start_point[i] + self.len[i] = valid_frames - 400 + if i == 0: + self.sample_start.append(0) + else: + self.sample_start.append(self.sample_start[-1] + self.len[i-1] - 1) + self.total_len += np.int32(np.floor(self.len[i] / self.frame_jump_stride)) + + + + def __getitem__(self, index): + # recover real index from compressed one + index_real = np.int32(index * self.frame_jump_stride) + # find which audio file and the start frame index + file_index = bisect.bisect_right(self.sample_start, index_real) - 1 + current_frame = index_real - self.sample_start[file_index] + self.start_point[file_index] + current_target_length = self.target_length + + if self.task == 'Audio2Feature': + # start point is current frame + A2Lsamples = self.audio_features[file_index][current_frame * 2 : (current_frame + self.seq_len) * 2] + target_pts3d = self.feats[file_index][current_frame : current_frame + self.seq_len, self.indices].reshape(self.seq_len, -1) + + A2Lsamples = torch.from_numpy(A2Lsamples).float() + target_pts3d = torch.from_numpy(target_pts3d).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Lsamples, target_pts3d + + + elif self.task == 'Audio2Headpose': + if self.opt.feature_decoder == 'WaveNet': + # find the history info start points + A2H_history_start = current_frame - self.A2H_receptive_field + A2H_item_length = self.A2H_item_length + A2H_receptive_field = self.A2H_receptive_field + + if self.half_audio_win == 1: + A2Hsamples = self.audio_features[file_index][2 * (A2H_history_start + self.frame_future) : 2 * (A2H_history_start + self.frame_future + A2H_item_length)] + else: + A2Hsamples = np.zeros([A2H_item_length, self.audio_window, 512]) + for i in range(A2H_item_length): + A2Hsamples[i] = self.audio_features[file_index][2 * (A2H_history_start + i) - self.half_audio_win : 2 * (A2H_history_start + i) + self.half_audio_win] + + if self.predict_len == 0: + target_headpose = self.headposes[file_index][A2H_history_start + A2H_receptive_field : A2H_history_start + A2H_item_length + 1] + history_headpose = self.headposes[file_index][A2H_history_start : A2H_history_start + A2H_item_length].reshape(A2H_item_length, -1) + + target_velocity = self.velocity_pose[file_index][A2H_history_start + A2H_receptive_field : A2H_history_start + A2H_item_length + 1] + history_velocity = self.velocity_pose[file_index][A2H_history_start : A2H_history_start + A2H_item_length].reshape(A2H_item_length, -1) + target_info = torch.from_numpy(np.concatenate([target_headpose, target_velocity], axis=1).reshape(current_target_length, -1)).float() + else: + history_headpose = self.headposes[file_index][A2H_history_start : A2H_history_start + A2H_item_length].reshape(A2H_item_length, -1) + history_velocity = self.velocity_pose[file_index][A2H_history_start : A2H_history_start + A2H_item_length].reshape(A2H_item_length, -1) + + + target_headpose_ = self.headposes[file_index][A2H_history_start + A2H_receptive_field - self.predict_len : A2H_history_start + A2H_item_length + 1 + self.predict_len + 1] + target_headpose = np.zeros([current_target_length, self.predict_length, target_headpose_.shape[1]]) + for i in range(current_target_length): + target_headpose[i] = target_headpose_[i : i + self.predict_length] + target_headpose = target_headpose#.reshape(current_target_length, -1, order='F') + + target_velocity_ = self.headposes[file_index][A2H_history_start + A2H_receptive_field - self.predict_len : A2H_history_start + A2H_item_length + 1 + self.predict_len + 1] + target_velocity = np.zeros([current_target_length, self.predict_length, target_velocity_.shape[1]]) + for i in range(current_target_length): + target_velocity[i] = target_velocity_[i : i + self.predict_length] + target_velocity = target_velocity#.reshape(current_target_length, -1, order='F') + + target_info = torch.from_numpy(np.concatenate([target_headpose, target_velocity], axis=2).reshape(current_target_length, -1)).float() + + A2Hsamples = torch.from_numpy(A2Hsamples).float() + + + history_info = torch.from_numpy(np.concatenate([history_headpose, history_velocity], axis=1)).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Hsamples, history_info, target_info + + + elif self.opt.feature_decoder == 'LSTM': + A2Hsamples = self.audio_features[file_index][current_frame * 2 : (current_frame + self.opt.A2H_receptive_field) * 2] + + target_headpose = self.headposes[file_index][current_frame : current_frame + self.opt.A2H_receptive_field] + target_velocity = self.velocity_pose[file_index][current_frame : current_frame + self.opt.A2H_receptive_field] + target_info = torch.from_numpy(np.concatenate([target_headpose, target_velocity], axis=1).reshape(self.opt.A2H_receptive_field, -1)).float() + + A2Hsamples = torch.from_numpy(A2Hsamples).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Hsamples, target_info + + + + def __len__(self): + return self.total_len + + + + + + diff --git a/LiveSpeechPortraits/source_code/datasets/base_dataset.py b/LiveSpeechPortraits/source_code/datasets/base_dataset.py new file mode 100644 index 00000000..193d19b5 --- /dev/null +++ b/LiveSpeechPortraits/source_code/datasets/base_dataset.py @@ -0,0 +1,64 @@ +"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" + +import numpy as np +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +from abc import ABC, abstractmethod + + +class BaseDataset(data.Dataset, ABC): + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. + -- : (optionally) add dataset-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + self.root = opt.dataroot + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns: + a dictionary of data with their names. It ususally contains the data itself and its metadata information. + """ + pass + + + + diff --git a/LiveSpeechPortraits/source_code/datasets/face_dataset.py b/LiveSpeechPortraits/source_code/datasets/face_dataset.py new file mode 100644 index 00000000..fa0e4b52 --- /dev/null +++ b/LiveSpeechPortraits/source_code/datasets/face_dataset.py @@ -0,0 +1,375 @@ +import os +from datasets.base_dataset import BaseDataset +import os.path +from pathlib import Path +import torch +from skimage.io import imread, imsave +from PIL import Image +import bisect +import numpy as np +import io +import cv2 +import h5py +import albumentations as A + + +class FaceDataset(BaseDataset): + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + + A few things can be done here. + - save the options (have been done in BaseDataset) + - get image paths and meta information of the dataset. + - define the image transformation. + """ + BaseDataset.__init__(self, opt) + self.state = 'Train' if self.opt.isTrain else 'Test' + self.dataset_name = opt.dataset_names[0] + + # default settings + # currently, we have 8 parts for face parts + self.part_list = [[list(range(0, 15))], # contour + [[15,16,17,18,18,19,20,15]], # right eyebrow + [[21,22,23,24,24,25,26,21]], # left eyebrow + [range(35, 44)], # nose + [[27,65,28,68,29], [29,67,30,66,27]], # right eye + [[33,69,32,72,31], [31,71,34,70,33]], # left eye + [range(46, 53), [52,53,54,55,56,57,46]], # mouth + [[46,63,62,61,52], [52,60,59,58,46]] # tongue + ] + self.mouth_outer = [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 46] + self.label_list = [1, 1, 2, 3, 3, 4, 5] # labeling for different facial parts + + # only load in train mode + + self.dataset_root = os.path.join(self.root, self.dataset_name) + if self.state == 'Train': + self.clip_names = opt.train_dataset_names + elif self.state == 'Val': + self.clip_names = opt.validate_dataset_names + elif self.state == 'Test': + self.clip_names = opt.test_dataset_names + + self.clip_nums = len(self.clip_names) + + # load pts & image info + self.landmarks2D, self.len, self.sample_len = [''] * self.clip_nums, [''] * self.clip_nums, [''] * self.clip_nums + self.image_transforms, self.image_pad, self.tgts_paths = [''] * self.clip_nums, [''] * self.clip_nums, [''] * self.clip_nums + self.shoulders, self.shoulder3D = [''] * self.clip_nums, [''] * self.clip_nums + self.sample_start = [] + + # tracked 3d info & candidates images + self.pts3d, self.rot, self.trans = [''] * self.clip_nums, [''] * self.clip_nums, [''] * self.clip_nums + self.full_cand = [''] * self.clip_nums + self.headposes = [''] * self.clip_nums + + self.total_len = 0 + if self.opt.isTrain: + for i in range(self.clip_nums): + name = self.clip_names[i] + clip_root = os.path.join(self.dataset_root, name) + # basic image info + img_file_path = os.path.join(clip_root, name + '.h5') + img_file = h5py.File(img_file_path, 'r')[name] + example = np.asarray(Image.open(io.BytesIO(img_file[0]))) + h, w, _ = example.shape + + + landmark_path = os.path.join(clip_root, 'tracked2D_normalized_pts_fix_contour.npy') + self.landmarks2D[i] = np.load(landmark_path).astype(np.float32) + change_paras = np.load(os.path.join(clip_root, 'change_paras.npz')) + scale, xc, yc = change_paras['scale'], change_paras['xc'], change_paras['yc'] + x_min, x_max, y_min, y_max = xc-256, xc+256, yc-256, yc+256 + # if need padding + x_min, x_max, y_min, y_max, self.image_pad[i] = max(x_min, 0), min(x_max, w), max(y_min, 0), min(y_max, h), None + + + if x_min == 0 or x_max == 512 or y_min == 0 or y_max == 512: + top, bottom, left, right = abs(yc-256-y_min), abs(yc+256-y_max), abs(xc-256-x_min), abs(xc+256-x_max) + self.image_pad[i] = [top, bottom, left, right] + self.image_transforms[i] = A.Compose([ + A.Resize(np.int32(h*scale), np.int32(w*scale)), + A.Crop(x_min, y_min, x_max, y_max)]) + + if self.opt.isH5: + tgt_file_path = os.path.join(clip_root, name + '.h5') + tgt_file = h5py.File(tgt_file_path, 'r')[name] + image_length = len(tgt_file) + else: + tgt_paths = list(map(lambda x:str(x), sorted(list(Path(clip_root).glob('*'+self.opt.suffix)), key=lambda x: int(x.stem)))) + image_length = len(tgt_paths) + self.tgts_paths[i] = tgt_paths + if not self.landmarks2D[i].shape[0] == image_length: + raise ValueError('In dataset {} length of landmarks and images are not equal!'.format(name)) + + # tracked 3d info + fit_data_path = os.path.join(clip_root, '3d_fit_data.npz') + fit_data = np.load(fit_data_path) + self.pts3d[i] = fit_data['pts_3d'].astype(np.float32) + self.rot[i] = fit_data['rot_angles'].astype(np.float32) + self.trans[i] = fit_data['trans'][:,:,0].astype(np.float32) + if not self.pts3d[i].shape[0] == image_length: + raise ValueError('In dataset {} length of 3d pts and images are not equal!'.format(name)) + + # candidates images + + tmp = [] + for j in range(4): + try: + output = imread(os.path.join(clip_root, 'candidates', f'normalized_full_{j}.jpg')) + except: + imgc = imread(os.path.join(clip_root, 'candidates', f'full_{j}.jpg')) + output = self.common_dataset_transform(imgc, i) + imsave(os.path.join(clip_root, 'candidates', f'normalized_full_{j}.jpg'), output) + output = A.pytorch.transforms.ToTensor(normalize={'mean':(0.5,0.5,0.5), 'std':(0.5,0.5,0.5)})(image=output)['image'] + tmp.append(output) + self.full_cand[i] = torch.cat(tmp) + + # headpose + fit_data_path = os.path.join(clip_root, '3d_fit_data.npz') + fit_data = np.load(fit_data_path) + rot_angles = fit_data['rot_angles'].astype(np.float32) + # change -180~180 to 0~360 + if not self.dataset_name == 'Yuxuan': + rot_change = rot_angles[:, 0] < 0 + rot_angles[rot_change, 0] += 360 + rot_angles[:,0] -= 180 # change x axis direction + # use delta translation + mean_trans = fit_data['trans'][:,:,0].astype(np.float32).mean(axis=0) + trans = fit_data['trans'][:,:,0].astype(np.float32) - mean_trans + + self.headposes[i] = np.concatenate([rot_angles, trans], axis=1) + + # shoulders + shoulder_path = os.path.join(clip_root, 'normalized_shoulder_points.npy') + self.shoulders[i] = np.load(shoulder_path) + shoulder3D_path = os.path.join(clip_root, 'shoulder_points3D.npy') + self.shoulder3D[i] = np.load(shoulder3D_path) + + + self.sample_len[i] = np.int32(np.floor((self.landmarks2D[i].shape[0] - 60) / self.opt.frame_jump) + 1) + self.len[i] = self.landmarks2D[i].shape[0] + if i == 0: + self.sample_start.append(0) + else: + self.sample_start.append(self.sample_start[-1] + self.sample_len[i-1]) # not minus 1 + self.total_len += self.sample_len[i] + + # test mode + else: + # if need padding + example = imread(os.path.join(self.root, 'example.png')) + h, w, _ = example.shape + change_paras = np.load(os.path.join(self.root, 'change_paras.npz')) + scale, xc, yc = change_paras['scale'], change_paras['xc'], change_paras['yc'] + x_min, x_max, y_min, y_max = xc-256, xc+256, yc-256, yc+256 + x_min, x_max, y_min, y_max, self.image_pad = max(x_min, 0), min(x_max, w), max(y_min, 0), min(y_max, h), None + + + if x_min == 0 or x_max == 512 or y_min == 0 or y_max == 512: + top, bottom, left, right = abs(yc-256-y_min), abs(yc+256-y_max), abs(xc-256-x_min), abs(xc+256-x_max) + self.image_pad = [top, bottom, left, right] + + + + + + def __getitem__(self, ind): + dataset_index = bisect.bisect_right(self.sample_start, ind) - 1 + data_index = (ind - self.sample_start[dataset_index]) * self.opt.frame_jump + np.random.randint(self.opt.frame_jump) + + target_ind = data_index + 1 # history_ind, current_ind + landmarks = self.landmarks2D[dataset_index][target_ind] # [73, 2] + shoulders = self.shoulders[dataset_index][target_ind].copy() + + dataset_name = self.clip_names[dataset_index] + clip_root = os.path.join(self.dataset_root, dataset_name) + if self.opt.isH5: + tgt_file_path = os.path.join(clip_root, dataset_name + '.h5') + tgt_file = h5py.File(tgt_file_path, 'r')[dataset_name] + tgt_image = np.asarray(Image.open(io.BytesIO(tgt_file[target_ind]))) + + # do transform + tgt_image = self.common_dataset_transform(tgt_image, dataset_index, None) + else: + pass + + h, w, _ = tgt_image.shape + + ### transformations & online data augmentations on images and landmarks + self.get_crop_coords(landmarks, (w, h), dataset_name, random_trans_scale=0) # 30.5 µs ± 348 ns random translation + + transform_tgt = self.get_transform(dataset_name, True, n_img=1, n_keypoint=1, flip=False) + transformed_tgt = transform_tgt(image=tgt_image, keypoints=landmarks) + + tgt_image, points = transformed_tgt['image'], np.array(transformed_tgt['keypoints']).astype(np.float32) + + feature_map = self.get_feature_image(points, (self.opt.loadSize, self.opt.loadSize), shoulders, self.image_pad[dataset_index])[np.newaxis, :].astype(np.float32)/255. + feature_map = torch.from_numpy(feature_map) + + ## facial weight mask + weight_mask = self.generate_facial_weight_mask(points, h, w)[None, :] + + cand_image = self.full_cand[dataset_index] + + return_list = {'feature_map': feature_map, 'cand_image': cand_image, 'tgt_image': tgt_image, 'weight_mask': weight_mask} + + return return_list + + + + + def common_dataset_transform(self, input, i): + output = self.image_transforms[i](image=input)['image'] + if self.image_pad[i] is not None: + top, bottom, left, right = self.image_pad[i] + output = cv2.copyMakeBorder(output, top, bottom, left, right, cv2.BORDER_CONSTANT, value = 0) + return output + + + + def generate_facial_weight_mask(self, points, h = 512, w = 512): + mouth_mask = np.zeros([512, 512, 1]) + points = points[self.mouth_outer] + points = np.int32(points) + mouth_mask = cv2.fillPoly(mouth_mask, [points], (255,0,0)) +# plt.imshow(mouth_mask[:,:,0]) + mouth_mask = cv2.dilate(mouth_mask, np.ones((45, 45))) / 255 + + return mouth_mask.astype(np.float32) + + + + def get_transform(self, dataset_name, keypoints=False, n_img=1, n_keypoint=1, normalize=True, flip=False): + min_x = getattr(self, 'min_x_' + str(dataset_name)) + max_x = getattr(self, 'max_x_' + str(dataset_name)) + min_y = getattr(self, 'min_y_' + str(dataset_name)) + max_y = getattr(self, 'max_y_' + str(dataset_name)) + + additional_flag = False + additional_targets_dict = {} + if n_img > 1: + additional_flag = True + image_str = ['image' + str(i) for i in range(0, n_img)] + for i in range(n_img): + additional_targets_dict[image_str[i]] = 'image' + if n_keypoint > 1: + additional_flag = True + keypoint_str = ['keypoint' + str(i) for i in range(0, n_keypoint)] + for i in range(n_keypoint): + additional_targets_dict[keypoint_str[i]] = 'keypoints' + + transform = A.Compose([ + A.Crop(x_min=min_x, x_max=max_x, y_min=min_y, y_max=max_y), + A.Resize(self.opt.loadSize, self.opt.loadSize), + A.HorizontalFlip(p=flip), + A.pytorch.transforms.ToTensor(normalize={'mean':(0.5,0.5,0.5), 'std':(0.5,0.5,0.5)} if normalize==True else None)], + keypoint_params=A.KeypointParams(format='xy', remove_invisible=False) if keypoints==True else None, + additional_targets=additional_targets_dict if additional_flag == True else None + ) + return transform + + + def get_data_test_mode(self, landmarks, shoulder, pad=None): + ''' get transformed data + ''' + + feature_map = torch.from_numpy(self.get_feature_image(landmarks, (self.opt.loadSize, self.opt.loadSize), shoulder, pad)[np.newaxis, :].astype(np.float32)/255.) + + return feature_map + + + def get_feature_image(self, landmarks, size, shoulders=None, image_pad=None): + # draw edges + im_edges = self.draw_face_feature_maps(landmarks, size) + if shoulders is not None: + if image_pad is not None: + top, bottom, left, right = image_pad + delta_y = top - bottom + delta_x = right - left + shoulders[:, 0] += delta_x + shoulders[:, 1] += delta_y + im_edges = self.draw_shoulder_points(im_edges, shoulders) + + + return im_edges + + + def draw_shoulder_points(self, img, shoulder_points): + num = int(shoulder_points.shape[0] / 2) + for i in range(2): + for j in range(num - 1): + pt1 = [int(flt) for flt in shoulder_points[i * num + j]] + pt2 = [int(flt) for flt in shoulder_points[i * num + j + 1]] + img = cv2.line(img, tuple(pt1), tuple(pt2), 255, 2) # BGR + + return img + + + def draw_face_feature_maps(self, keypoints, size=(512, 512)): + w, h = size + # edge map for face region from keypoints + im_edges = np.zeros((h, w), np.uint8) # edge map for all edges + for edge_list in self.part_list: + for edge in edge_list: + for i in range(len(edge)-1): + pt1 = [int(flt) for flt in keypoints[edge[i]]] + pt2 = [int(flt) for flt in keypoints[edge[i + 1]]] + im_edges = cv2.line(im_edges, tuple(pt1), tuple(pt2), 255, 2) + + return im_edges + + + def get_crop_coords(self, keypoints, size, dataset_name, random_trans_scale=50): + # cut a rought region for fine cutting + # here x towards right and y towards down, origin is left-up corner + w_ori, h_ori = size + min_y, max_y = keypoints[:,1].min(), keypoints[:,1].max() + min_x, max_x = keypoints[:,0].min(), keypoints[:,0].max() + xc = (min_x + max_x) // 2 + yc = (min_y*3 + max_y) // 4 + h = w = min((max_x - min_x) * 2, w_ori, h_ori) + + if self.opt.isTrain: + # do online augment on landmarks & images + # 1. random translation: move 10% + x_bias, y_bias = np.random.uniform(-random_trans_scale, random_trans_scale, size=(2,)) + xc, yc = xc + x_bias, yc + y_bias + + # modify the center x, center y to valid position + xc = min(max(0, xc - w//2) + w, w_ori) - w//2 + yc = min(max(0, yc - h//2) + h, h_ori) - h//2 + + min_x, max_x = xc - w//2, xc + w//2 + min_y, max_y = yc - h//2, yc + h//2 + + setattr(self, 'min_x_' + str(dataset_name), int(min_x)) + setattr(self, 'max_x_' + str(dataset_name), int(max_x)) + setattr(self, 'min_y_' + str(dataset_name), int(min_y)) + setattr(self, 'max_y_' + str(dataset_name), int(max_y)) + + + def crop(self, img, dataset_name): + min_x = getattr(self, 'min_x_' + str(dataset_name)) + max_x = getattr(self, 'max_x_' + str(dataset_name)) + min_y = getattr(self, 'min_y_' + str(dataset_name)) + max_y = getattr(self, 'max_y_' + str(dataset_name)) + if isinstance(img, np.ndarray): + return img[min_y:max_y, min_x:max_x] + else: + return img.crop((min_x, min_y, max_x, max_y)) + + + def __len__(self): + if self.opt.isTrain: + return self.total_len + else: + return 1 + + def name(self): + return 'FaceDataset' + + diff --git a/LiveSpeechPortraits/source_code/demo.py b/LiveSpeechPortraits/source_code/demo.py new file mode 100644 index 00000000..0143af58 --- /dev/null +++ b/LiveSpeechPortraits/source_code/demo.py @@ -0,0 +1,307 @@ +import os +import subprocess +from os.path import join +from tqdm import tqdm +import numpy as np +import torch +from collections import OrderedDict +import librosa +from skimage.io import imread +import cv2 +import scipy.io as sio +import argparse +import yaml +import albumentations as A +import albumentations.pytorch +from pathlib import Path + +from options.test_audio2feature_options import TestOptions as FeatureOptions +from options.test_audio2headpose_options import TestOptions as HeadposeOptions +from options.test_feature2face_options import TestOptions as RenderOptions + +from datasets import create_dataset +from models import create_model +from models.networks import APC_encoder +import util.util as util +from util.visualizer import Visualizer +from funcs import utils +from funcs import audio_funcs + +import warnings +warnings.filterwarnings("ignore") + + + +def write_video_with_audio(audio_path, output_path, prefix='pred_'): + fps, fourcc = 60, cv2.VideoWriter_fourcc(*'DIVX') + video_tmp_path = join(save_root, 'tmp.avi') + out = cv2.VideoWriter(video_tmp_path, fourcc, fps, (Renderopt.loadSize, Renderopt.loadSize)) + for j in tqdm(range(nframe), position=0, desc='writing video'): + img = cv2.imread(join(save_root, prefix + str(j+1) + '.jpg')) + out.write(img) + out.release() + cmd = 'ffmpeg -i "' + video_tmp_path + '" -i "' + audio_path + '" -codec copy -shortest "' + output_path + '"' + subprocess.call(cmd, shell=True) + os.remove(video_tmp_path) # remove the template video + + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--id', default='May', help="person name, e.g. Obama1, Obama2, May, Nadella, McStay") + parser.add_argument('--driving_audio', default='./data/input/00083.wav', help="path to driving audio") + parser.add_argument('--save_intermediates', default=0, help="whether to save intermediate results") + parser.add_argument('--device', type=str, default='cpu', help='use cuda for GPU or use cpu for CPU') + + + ############################### I/O Settings ############################## + # load config files + opt = parser.parse_args() + device = torch.device(opt.device) + + with open(join('./config/', opt.id + '.yaml')) as f: + config = yaml.load(f, Loader=yaml.SafeLoader) + data_root = join('./data/', opt.id) + # create the results folder + audio_name = os.path.split(opt.driving_audio)[1][:-4] + save_root = join('./results/', opt.id, audio_name) + if not os.path.exists(save_root): + os.makedirs(save_root) + + + + ############################ Hyper Parameters ############################# + h, w, sr, FPS = 512, 512, 16000, 60 + mouth_indices = np.concatenate([np.arange(4, 11), np.arange(46, 64)]) + eye_brow_indices = [27, 65, 28, 68, 29, 67, 30, 66, 31, 72, 32, 69, 33, 70, 34, 71] + eye_brow_indices = np.array(eye_brow_indices, np.int32) + + + + ############################ Pre-defined Data ############################# + mean_pts3d = np.load(join(data_root, 'mean_pts3d.npy')) + fit_data = np.load(config['dataset_params']['fit_data_path']) + pts3d = np.load(config['dataset_params']['pts3d_path']) - mean_pts3d + trans = fit_data['trans'][:,:,0].astype(np.float32) + mean_translation = trans.mean(axis=0) + candidate_eye_brow = pts3d[10:, eye_brow_indices] + std_mean_pts3d = np.load(config['dataset_params']['pts3d_path']).mean(axis=0) + # candidates images + img_candidates = [] + for j in range(4): + output = imread(join(data_root, 'candidates', f'normalized_full_{j}.jpg')) + output = A.pytorch.transforms.ToTensor(normalize={'mean':(0.5,0.5,0.5), + 'std':(0.5,0.5,0.5)})(image=output)['image'] + img_candidates.append(output) + img_candidates = torch.cat(img_candidates).unsqueeze(0).to(device) + + # shoulders + shoulders = np.load(join(data_root, 'normalized_shoulder_points.npy')) + shoulder3D = np.load(join(data_root, 'shoulder_points3D.npy'))[1] + ref_trans = trans[1] + + # camera matrix, we always use training set intrinsic parameters. + camera = utils.camera() + camera_intrinsic = np.load(join(data_root, 'camera_intrinsic.npy')).astype(np.float32) + APC_feat_database = np.load(join(data_root, 'APC_feature_base.npy')) + + # load reconstruction data + scale = sio.loadmat(join(data_root, 'id_scale.mat'))['scale'][0,0] + # Audio2Mel_torch = audio_funcs.Audio2Mel(n_fft=512, hop_length=int(16000/120), win_length=int(16000/60), sampling_rate=16000, + # n_mel_channels=80, mel_fmin=90, mel_fmax=7600.0).to(device) + + + + ########################### Experiment Settings ########################### + #### user config + use_LLE = config['model_params']['APC']['use_LLE'] + Knear = config['model_params']['APC']['Knear'] + LLE_percent = config['model_params']['APC']['LLE_percent'] + headpose_sigma = config['model_params']['Headpose']['sigma'] + Feat_smooth_sigma = config['model_params']['Audio2Mouth']['smooth'] + Head_smooth_sigma = config['model_params']['Headpose']['smooth'] + Feat_center_smooth_sigma, Head_center_smooth_sigma = 0, 0 + AMP_method = config['model_params']['Audio2Mouth']['AMP'][0] + Feat_AMPs = config['model_params']['Audio2Mouth']['AMP'][1:] + rot_AMP, trans_AMP = config['model_params']['Headpose']['AMP'] + shoulder_AMP = config['model_params']['Headpose']['shoulder_AMP'] + save_feature_maps = config['model_params']['Image2Image']['save_input'] + + #### common settings + Featopt = FeatureOptions().parse() + Headopt = HeadposeOptions().parse() + Renderopt = RenderOptions().parse() + Featopt.load_epoch = config['model_params']['Audio2Mouth']['ckp_path'] + Headopt.load_epoch = config['model_params']['Headpose']['ckp_path'] + Renderopt.dataroot = config['dataset_params']['root'] + Renderopt.load_epoch = config['model_params']['Image2Image']['ckp_path'] + Renderopt.size = config['model_params']['Image2Image']['size'] + ## GPU or CPU + if opt.device == 'cpu': + Featopt.gpu_ids = Headopt.gpu_ids = Renderopt.gpu_ids = [] + + + + ############################# Load Models ################################# + print('---------- Loading Model: APC-------------') + APC_model = APC_encoder(config['model_params']['APC']['mel_dim'], + config['model_params']['APC']['hidden_size'], + config['model_params']['APC']['num_layers'], + config['model_params']['APC']['residual']) + APC_model.load_state_dict(torch.load(config['model_params']['APC']['ckp_path']), strict=False) + if opt.device == 'cuda': + APC_model.cuda() + APC_model.eval() + print('---------- Loading Model: {} -------------'.format(Featopt.task)) + Audio2Feature = create_model(Featopt) + Audio2Feature.setup(Featopt) + Audio2Feature.eval() + print('---------- Loading Model: {} -------------'.format(Headopt.task)) + Audio2Headpose = create_model(Headopt) + Audio2Headpose.setup(Headopt) + Audio2Headpose.eval() + if Headopt.feature_decoder == 'WaveNet': + if opt.device == 'cuda': + Headopt.A2H_receptive_field = Audio2Headpose.Audio2Headpose.module.WaveNet.receptive_field + else: + Headopt.A2H_receptive_field = Audio2Headpose.Audio2Headpose.WaveNet.receptive_field + print('---------- Loading Model: {} -------------'.format(Renderopt.task)) + facedataset = create_dataset(Renderopt) + Feature2Face = create_model(Renderopt) + Feature2Face.setup(Renderopt) + Feature2Face.eval() + visualizer = Visualizer(Renderopt) + + + + ############################## Inference ################################## + print('Processing audio: {} ...'.format(audio_name)) + # read audio + audio, _ = librosa.load(opt.driving_audio, sr=sr) + total_frames = np.int32(audio.shape[0] / sr * FPS) + + + #### 1. compute APC features + print('1. Computing APC features...') + mel80 = utils.compute_mel_one_sequence(audio, device=opt.device) + mel_nframe = mel80.shape[0] + with torch.no_grad(): + length = torch.Tensor([mel_nframe]) + mel80_torch = torch.from_numpy(mel80.astype(np.float32)).to(device).unsqueeze(0) + hidden_reps = APC_model.forward(mel80_torch, length)[0] # [mel_nframe, 512] + hidden_reps = hidden_reps.cpu().numpy() + audio_feats = hidden_reps + + + #### 2. manifold projection + if use_LLE: + print('2. Manifold projection...') + ind = utils.KNN_with_torch(audio_feats, APC_feat_database, K=Knear) + weights, feat_fuse = utils.compute_LLE_projection_all_frame(audio_feats, APC_feat_database, ind, audio_feats.shape[0]) + audio_feats = audio_feats * (1-LLE_percent) + feat_fuse * LLE_percent + + + #### 3. Audio2Mouth + print('3. Audio2Mouth inference...') + pred_Feat = Audio2Feature.generate_sequences(audio_feats, sr, FPS, fill_zero=True, opt=Featopt) + + + #### 4. Audio2Headpose + print('4. Headpose inference...') + # set history headposes as zero + pre_headpose = np.zeros(Headopt.A2H_wavenet_input_channels, np.float32) + pred_Head = Audio2Headpose.generate_sequences(audio_feats, pre_headpose, fill_zero=True, sigma_scale=0.3, opt=Headopt) + + + #### 5. Post-Processing + print('5. Post-processing...') + nframe = min(pred_Feat.shape[0], pred_Head.shape[0]) + pred_pts3d = np.zeros([nframe, 73, 3]) + pred_pts3d[:, mouth_indices] = pred_Feat.reshape(-1, 25, 3)[:nframe] + + ## mouth + pred_pts3d = utils.landmark_smooth_3d(pred_pts3d, Feat_smooth_sigma, area='only_mouth') + pred_pts3d = utils.mouth_pts_AMP(pred_pts3d, True, AMP_method, Feat_AMPs) + pred_pts3d = pred_pts3d + mean_pts3d + pred_pts3d = utils.solve_intersect_mouth(pred_pts3d) # solve intersect lips if exist + + ## headpose + pred_Head[:, 0:3] *= rot_AMP + pred_Head[:, 3:6] *= trans_AMP + pred_headpose = utils.headpose_smooth(pred_Head[:,:6], Head_smooth_sigma).astype(np.float32) + pred_headpose[:, 3:] += mean_translation + pred_headpose[:, 0] += 180 + + ## compute projected landmarks + pred_landmarks = np.zeros([nframe, 73, 2], dtype=np.float32) + final_pts3d = np.zeros([nframe, 73, 3], dtype=np.float32) + final_pts3d[:] = std_mean_pts3d.copy() + final_pts3d[:, 46:64] = pred_pts3d[:nframe, 46:64] + for k in tqdm(range(nframe)): + ind = k % candidate_eye_brow.shape[0] + final_pts3d[k, eye_brow_indices] = candidate_eye_brow[ind] + mean_pts3d[eye_brow_indices] + pred_landmarks[k], _, _ = utils.project_landmarks(camera_intrinsic, camera.relative_rotation, + camera.relative_translation, scale, + pred_headpose[k], final_pts3d[k]) + + ## Upper Body Motion + pred_shoulders = np.zeros([nframe, 18, 2], dtype=np.float32) + pred_shoulders3D = np.zeros([nframe, 18, 3], dtype=np.float32) + for k in range(nframe): + diff_trans = pred_headpose[k][3:] - ref_trans + pred_shoulders3D[k] = shoulder3D + diff_trans * shoulder_AMP + # project + project = camera_intrinsic.dot(pred_shoulders3D[k].T) + project[:2, :] /= project[2, :] # divide z + pred_shoulders[k] = project[:2, :].T + + + #### 6. Image2Image translation & Save resuls + print('6. Image2Image translation & Saving results...') + for ind in tqdm(range(0, nframe), desc='Image2Image translation inference'): + # feature_map: [input_nc, h, w] + current_pred_feature_map = facedataset.dataset.get_data_test_mode(pred_landmarks[ind], + pred_shoulders[ind], + facedataset.dataset.image_pad) + input_feature_maps = current_pred_feature_map.unsqueeze(0).to(device) + pred_fake = Feature2Face.inference(input_feature_maps, img_candidates) + # save results + visual_list = [('pred', util.tensor2im(pred_fake[0]))] + if save_feature_maps: + visual_list += [('input', np.uint8(current_pred_feature_map[0].cpu().numpy() * 255))] + visuals = OrderedDict(visual_list) + visualizer.save_images(save_root, visuals, str(ind+1)) + + + ## make videos + # generate corresponding audio, reused for all results + tmp_audio_path = join(save_root, 'tmp.wav') + tmp_audio_clip = audio[ : np.int32(nframe * sr / FPS)] + librosa.output.write_wav(tmp_audio_path, tmp_audio_clip, sr) + + + final_path = join(save_root, audio_name + '.avi') + write_video_with_audio(tmp_audio_path, final_path, 'pred_') + feature_maps_path = join(save_root, audio_name + '_feature_maps.avi') + write_video_with_audio(tmp_audio_path, feature_maps_path, 'input_') + + if os.path.exists(tmp_audio_path): + os.remove(tmp_audio_path) + if not opt.save_intermediates: + _img_paths = list(map(lambda x:str(x), list(Path(save_root).glob('*.jpg')))) + for i in tqdm(range(len(_img_paths)), desc='deleting intermediate images'): + os.remove(_img_paths[i]) + + + print('Finish!') + + + + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/doc/Teaser.jpg b/LiveSpeechPortraits/source_code/doc/Teaser.jpg new file mode 100644 index 00000000..e5ddd92d Binary files /dev/null and b/LiveSpeechPortraits/source_code/doc/Teaser.jpg differ diff --git a/LiveSpeechPortraits/source_code/entry_point.py b/LiveSpeechPortraits/source_code/entry_point.py new file mode 100644 index 00000000..a0affd05 --- /dev/null +++ b/LiveSpeechPortraits/source_code/entry_point.py @@ -0,0 +1,34 @@ +import subprocess +import argparse + +def run_lspmodel(args): + # 构建命令,包含 lspmodel 后面的所有参数 + command = ["python", "demo.py"] + args + subprocess.run(command) + +def run_eval(args): + # 构建命令,先进入 judge_models 目录,然后运行 run_judge.py + command = ["cd", "judge_models", "&&", "python", "run_judge.py"] + args + subprocess.run(command, shell=True) + +def main(): + # 设置命令行参数解析器 + parser = argparse.ArgumentParser(description="运行不同模式的脚本") + + # --lspmodel 和 --eval 是选项 + parser.add_argument("--lspmodel", action="store_true", help="运行 lspmodel 模式") + parser.add_argument("--eval", action="store_true", help="运行 eval 模式") + + # 使用 parse_known_args 解析已知参数和多余参数 + args, unknown_args = parser.parse_known_args() + + # 根据 --lspmodel 或 --eval 执行不同的命令 + if args.lspmodel: + run_lspmodel(unknown_args) + elif args.eval: + run_eval(unknown_args) + else: + print("Please specify --lspmodel or --eval to decide the mode.") + +if __name__ == "__main__": + main() diff --git a/LiveSpeechPortraits/source_code/funcs/audio_funcs.py b/LiveSpeechPortraits/source_code/funcs/audio_funcs.py new file mode 100644 index 00000000..c6917734 --- /dev/null +++ b/LiveSpeechPortraits/source_code/funcs/audio_funcs.py @@ -0,0 +1,437 @@ +import os +import os.path +import math +# import sox +#import pyworld as pw +import torch +import torch.utils.data +import numpy as np +import librosa + + +""" +useage +fft = Audio2Mel().cuda() +# audio shape is B x 1 x T, the normalized mel shape is B x D x T +mel = fft(audio) +""" +from librosa.filters import mel as librosa_mel_fn +import torch.nn.functional as F +class Audio2Mel(torch.nn.Module): + def __init__( + self, + n_fft=512, + hop_length=256, + win_length=1024, + sampling_rate=16000, + n_mel_channels=80, + mel_fmin=90, + mel_fmax=7600.0, + ): + super(Audio2Mel, self).__init__() + ############################################## + # FFT Parameters # + ############################################## + window = torch.hann_window(win_length).float() + mel_basis = librosa_mel_fn( + sampling_rate, n_fft, n_mel_channels, mel_fmin, mel_fmax + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + self.register_buffer("window", window) + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.sampling_rate = sampling_rate + self.n_mel_channels = n_mel_channels + self.min_mel = math.log(1e-5) + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + + """ + input audio signal (-1,1): B x 1 x T + output mel signal: B x D x T', T' is a reduction of T + """ + + def forward(self, audio, normalize=True): + p = (self.n_fft - self.hop_length) // 2 + audio = F.pad(audio, (p, p), "reflect").squeeze(1) + fft = torch.stft( + audio, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=False, + ) + real_part, imag_part = fft.unbind(-1) + magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) + mel_output = torch.matmul(self.mel_basis, magnitude) + log_mel_spec = torch.log(torch.clamp(mel_output, min=1e-5)) + + # normalize to the range [0,1] + if normalize: + log_mel_spec = (log_mel_spec - self.min_mel) / -self.min_mel + return log_mel_spec + + def mel_to_audio(self, mel): + mel = torch.exp(mel * (-self.min_mel) + self.min_mel) ** 2 + mel_np = mel.cpu().numpy() + audio = librosa.feature.inverse.mel_to_audio(mel_np, sr=self.sampling_rate, n_fft=self.n_fft, + hop_length=self.hop_length, win_length=self.win_length, + window='hann', center=False, + pad_mode='reflect', power=2.0, n_iter=32, fmin=self.mel_fmin, + fmax=self.mel_fmax) + return audio + + """ + here we will get per frame energy to replace mc0 in the corresponding prosody representation + the audio is already in the gpu card for accerelate the computation speed + input audio signal: B x 1 x T + output energy: B x 1 x T' + """ + + def get_energy(self, audio, normalize=True): + # B x 1 x T + p = (self.n_fft - self.hop_length) // 2 + audio_new = F.pad(audio, (p, p), "reflect").squeeze(1) + # audio_new = audio.squeeze(1) + audio_fold = audio_new.unfold(1, self.win_length, self.hop_length) + audio_energy = torch.sqrt(torch.mean(audio_fold ** 2, dim=-1)) + audio_energy = torch.log(torch.clamp(audio_energy, min=1e-5)) + if normalize: + audio_energy = (audio_energy - self.min_mel) / -self.min_mel + return audio_energy + + # we can get the energy of mels here, B*D*T + def get_energy_mel(self, mels, normalize=True): + m = mels.exp().mean(dim=1) + audio_energy = torch.log(m) + # audio_energy = torch.log(torch.clamp(m,min=1e-5)) + # if normalize: + # audio_energy = (audio_energy - self.min_mel) / -self.min_mel + return audio_energy + + + + +def mu_law_encoding(data, mu=255): + '''encode the original audio via mu-law companding and mu-bits quantization + ''' + # mu-law companding + mu_x = np.sign(data) * np.log(1 + mu * np.abs(data)) / np.log(mu + 1) + # mu-bits quantization from [-1, 1] to [0, mu] + mu_x = (mu_x + 1) / 2 * mu + 0.5 + return mu_x.astype(np.int32) + +#%timeit mu_x = mu_law_encoding(x, 255) 305 µs ± 554 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) + + +def mu_law_decoding(data, mu=255): + '''inverse the mu-law compressed and quantized data. + ''' + # dequantization + y = 2 * (data.astype(np.float32) / mu) - 1 + # inverse mu-law companding + x = np.sign(y) * (1.0 / mu) * ((1.0 + mu)**abs(y) - 1.0) + return x + + + +## audio augmentation +def inject_gaussian_noise(data, noise_factor, use_torch=False): + ''' inject random gaussian noise (mean=0, std=1) to audio clip + In my test, a reasonable factor region could be [0, 0.01] + larger will be too large and smaller could be ignored. + Args: + data: [n,] original audio sequence + noise_factor(float): scaled factor + use_torch(bool): optional, if use_torch=True, input data and implementation will + be torch methods. + Returns: + augmented_data: [n,] noised audio clip + + ''' + if use_torch == False: + augmented_data = data + noise_factor * np.random.normal(0, 1, len(data)) + # Cast back to same data type + augmented_data = augmented_data.astype(type(data[0])) + # use torch + else: + augmented_data = data + noise_factor * torch.randn(1).cuda() + + return augmented_data + + +# pitch shifting +def pitch_shifting(data, sampling_rate=48000, factor=5): + ''' shift the audio pitch. + ''' + # Permissible factor values = -5 <= x <= 5 + pitch_factor = np.random.rand(1) * 2 * factor - factor + return librosa.effects.pitch_shift(data, sampling_rate, pitch_factor) + + +def speed_change(data, landmark=None): + ''' change the speed of input audio. Note that we return the speed_rate to + change the speed of landmarks or videos. + Args: + data: [n,] audio clip + landmark: [m, pts, 2] aligned landmarks with audio if existed. + ''' + # Permissible factor values = 0.7 <= x <= 1.3 (higher is faster) + # resulted audio length: np.round(n/rate) + speed_rate = np.random.uniform(0.7, 1.3) + # only augment audio + if landmark == None: + return librosa.effects.time_stretch(data, speed_rate), speed_rate + else: +# n_after = np.round(data.shape[0]/speed_rate) + pass + + + + +def world_augment(wav, sr, op): + f0, sp, ap = pw.wav2world(wav.astype(np.float64), sr) + op = op if op is not None else np.random.randint(0,4) + if op == 0: + base_f0 = np.random.randint(100,300) +# base_f0 = np.random.randint(100, 200) + robot_like_f0 = np.ones_like(f0) * base_f0 # 100是个适当的数字 + robot_like = pw.synthesize(robot_like_f0, sp, ap, sr) + out_wav = robot_like + elif op == 1: + ratio = 1 + np.random.rand() + female_like_sp = np.zeros_like(sp) + for f in range(female_like_sp.shape[1]): + female_like_sp[:, f] = sp[:, int(f/ratio)] + ratio_f = 0.65 + 1.4 * np.random.rand() + out_wav = pw.synthesize(f0*ratio_f, female_like_sp, ap, sr) + elif op == 2: + # change the current pitch here + ratio = 0.65 + 1.4 * np.random.rand() + out_wav = pw.synthesize(f0*ratio, sp, ap, sr) + elif op == 3: + # the random masking using the time axis + mask_len = np.random.randint(0,256 * 4) + mask_pos = np.random.randint(0, wav.shape[0] - mask_len + 1) + out_wav = np.copy(wav) + out_wav[mask_pos:mask_pos+mask_len] = 0 + else: + out_wav = np.copy(wav) + + return out_wav.astype(np.float32) + + +def sox_augment(wav, sr, tempo_ratio=1.0, op=None): + aug_choice = op if op is not None else np.random.randint(low=1, high=8) +# tempo_ratio = 1.0 + hop_length = 256 + tfm = sox.Transformer() + if aug_choice == 1: + # 1 pitch aug + param = np.random.uniform(-5.0, 5.0) + tfm.pitch(param) + elif aug_choice == 2: + # 2 tempo aug, when tempo_ratio is around 1.0, no tempo aug +# tempo_ratio = np.random.uniform(0.5, 2.0) +# if tempo_ratio >= 0.9 and tempo_ratio <= 1.1: +# tempo_ratio = 1.0 +# if tempo_ratio != 1.0: +# tfm.tempo(tempo_ratio, 's', quick=False) + pass + elif aug_choice == 3: + # 3 gain aug + param = np.random.uniform(-20, 5) + tfm.norm() + tfm.gain(param) + elif aug_choice == 4: + # 4 echo aug + # delays = np.random.uniform(5, 60) + # decays = np.random.uniform(0.2, 0.6) + # tfm.echo(delays=[delays], decays=[decays]) + pass + elif aug_choice == 5: + # 5 reverb aug + param = np.random.uniform(0, 100, size=(4,)) + tfm.reverb(param[0], param[1], param[2], param[3]) + elif aug_choice == 6: + # 6 bandreject aug + param1 = np.random.randint(200, 7500) + param2 = np.random.uniform(1.0, 4.0) + tfm.bandreject(param1, param2) + elif aug_choice == 7: + # 8 equalizer aug + param1 = np.random.randint(200, 7500) + param2 = np.random.uniform(1.0, 4.0) + param3 = np.random.uniform(-20, 5) + tfm.equalizer(param1, param2, param3) + else: + raise RuntimeError('Aug choice error!') + + wave_length = wav.shape[0] + if aug_choice == 1:# When using pitch augmentation, pad silence to keep audio length + wav = np.concatenate((wav, np.array([0.0]*(hop_length * 2))), axis=0) + + aug_wave_data = tfm.build_array(input_array=wav, sample_rate_in=sr) + + if aug_choice == 1:# Keep audio length unchanged when using pitch augmentation + aug_wave_data = aug_wave_data[:wave_length] + + return aug_wave_data + + +def sox_augment_v2(wav, sr, op=None): + aug_choice = op if op is not None else np.random.randint(low=1, high=5) + hop_length = 256 + tfm = sox.Transformer() + if aug_choice == 1: + # 1 pitch aug + param = np.random.uniform(-5.0, 5.0) + tfm.pitch(param) + elif aug_choice == 2: + # 5 reverb aug + param = np.random.uniform(0, 100, size=(4,)) + tfm.reverb(param[0], param[1], param[2], param[3]) + elif aug_choice == 3: + # 6 bandreject aug + param1 = np.random.randint(200, 7500) + param2 = np.random.uniform(1.0, 4.0) + tfm.bandreject(param1, param2) + elif aug_choice == 4: + # 8 equalizer aug + param1 = np.random.randint(200, 7500) + param2 = np.random.uniform(1.0, 4.0) + param3 = np.random.uniform(-20, 5) + tfm.equalizer(param1, param2, param3) + else: + raise RuntimeError('Aug choice error!') + + wave_length = wav.shape[0] + if aug_choice == 1:# When using pitch augmentation, pad silence to keep audio length + wav = np.concatenate((wav, np.array([0.0]*(hop_length * 2))), axis=0) + + aug_wave_data = tfm.build_array(input_array=wav, sample_rate_in=sr) + + if aug_choice == 1:# Keep audio length unchanged when using pitch augmentation + aug_wave_data = aug_wave_data[:wave_length] + + return aug_wave_data + + +def audio_output_augment(wav, sr, op=None): + aug_choice = op if op is not None else np.random.randint(low=1, high=4) + tfm = sox.Transformer() + if aug_choice == 1: + # 5 reverb aug + param = np.random.uniform(0, 100, size=(4,)) + tfm.reverb(param[0], param[1], param[2], param[3]) + elif aug_choice == 2: + # 6 bandreject aug + param1 = np.random.randint(200, 7500) + param2 = np.random.uniform(1.0, 4.0) + tfm.bandreject(param1, param2) + elif aug_choice == 3: + # 8 equalizer aug + param1 = np.random.randint(200, 7500) + param2 = np.random.uniform(1.0, 4.0) + param3 = np.random.uniform(-20, 5) + tfm.equalizer(param1, param2, param3) + else: + raise RuntimeError('Aug choice error!') + + aug_wave_data = tfm.build_array(input_array=wav, sample_rate_in=sr) + + return aug_wave_data + + +def audio_time_augment(wav, sr, time_scale): + tfm = sox.Transformer() + tfm.tempo(time_scale, 's', quick=False) + + aug_wave_data = tfm.build_array(input_array=wav, sample_rate_in=sr) + + return aug_wave_data + + +def prepare_noises(scp_file, root=None, sampline_rate=None, ignore_class=None): + noises = [] + print('Loading augmentation noises...') + with open(scp_file,'r') as fp: + for line in fp.readlines(): + line = line.rstrip('\n') + if ignore_class is not None and ignore_class in line: + continue + + noise, sr = librosa.load(os.path.join(root, line), sr=sampline_rate) + noises.append(noise) + print('Augmentation noises loaded!') + return noises, sr + + +def add_gauss_noise(wav, noise_std=0.03, max_wav_value=1.0): + if isinstance(wav, np.ndarray): + wav = torch.tensor(wav.copy()) + + real_std = np.random.random() * noise_std + wav_new = wav.float() / max_wav_value + torch.randn(wav.size()) * real_std + wav_new = wav_new * max_wav_value + wav_new = wav_new.clamp_(-max_wav_value, max_wav_value) + + return wav_new.float().numpy() + +def add_background_noise(wav, noises, min_snr=2, max_snr=15): + def mix_noise(wav, noise, scale): + x = wav + scale * noise + x = x.clip(-1, 1) + return x + + def voice_energy(wav): + wav_float = np.copy(wav) + return np.sum(wav_float ** 2) / (wav_float.shape[0] + 1e-5) + + def voice_energy_ratio(wav, noise, target_snr): + wav_eng = voice_energy(wav) + noise_eng = voice_energy(noise) + target_noise_eng = wav_eng / (10 ** (target_snr / 10.0)) + ratio = target_noise_eng / (noise_eng + 1e-5) + return ratio + + total_id = len(noises) + # 0 is no need to generate the noise + idx = np.random.choice(range(0, total_id)) + noise_wav = noises[idx] + if noise_wav.shape[0] > wav.shape[0]: + sel_range_id = np.random.choice(range(0, noise_wav.shape[0] - wav.shape[0])) + n = noise_wav[sel_range_id:sel_range_id + wav.shape[0]] + else: + n = np.zeros(wav.shape[0]) + sel_range_id = np.random.choice(range(0, wav.shape[0] - noise_wav.shape[0] + 1)) + n[sel_range_id:sel_range_id + noise_wav.shape[0]] = noise_wav + # + target_snr = np.random.random() * (max_snr - min_snr) + min_snr + scale = voice_energy_ratio(wav, n, target_snr) + wav_new = mix_noise(wav, n, scale) + return wav_new + + +def noise_augment(wav, wav_noises, gaussian_prob=0.5): + if np.random.random() > gaussian_prob:# add gauss noise + noise_std = np.random.uniform(low=0.001, high=0.02) + aug_wave_data = add_gauss_noise(wav, noise_std=noise_std) + else:# add background noise + aug_wave_data = add_background_noise(wav, wav_noises, min_snr=2, max_snr=15) + + return aug_wave_data + + + + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/funcs/utils.py b/LiveSpeechPortraits/source_code/funcs/utils.py new file mode 100644 index 00000000..9db9d90c --- /dev/null +++ b/LiveSpeechPortraits/source_code/funcs/utils.py @@ -0,0 +1,375 @@ +import sys +sys.path.append("..") +from . import audio_funcs + +import numpy as np +from math import cos, sin +import torch +from numpy.linalg import solve +from scipy.ndimage import gaussian_filter1d +from sklearn.neighbors import KDTree +import time +from tqdm import tqdm + + +class camera(object): + def __init__(self, fx=0, fy=0, cx=0, cy=0): + self.name = 'default camera' + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + self.relative_rotation = np.diag([1,1,1]).astype(np.float32) + self.relative_translation = np.zeros(3, dtype=np.float32) +# self.intrinsic = np.array([[self.fx, 0, self.cx], +# [0, self.fy, self.cy], +# [0, 0, 1]]) + + def intrinsic(self, trans_matrix=0): + ''' compute the intrinsic matrix + ''' + intrinsic = np.array([[self.fx, 0, self.cx], + [0, self.fy, self.cy], + [0, 0, 1]]) + + return intrinsic + + def relative(self): + ''' compute the relative transformation 4x4 matrix with respect to the + first camera kinect. specially the kinect's relative transformation + matrix is exact a identity matrix. + ''' + relative = np.eye(4, dtype=np.float32) + relative[:3, :3] = self.relative_rotation + relative[:3, 3] = self.relative_translation + + return relative + + def transform_intrinsic(self, transform_matrix): + ''' change the camera intrinsic matrix + transformed_intrinsic = transform_matrix * intrinsic + ''' + scale = transform_matrix[0,0] + self.fx *= scale + self.fy *= scale + self.cx = scale * self.cx + transform_matrix[0, 2] + self.cy = scale * self.cy + transform_matrix[1, 2] + + + + +def compute_mel_one_sequence(audio, hop_length=int(16000/120), winlen=1/60, winstep=0.5/60, sr=16000, fps=60, device='cpu'): + ''' compute mel for an audio sequence. + ''' + device = torch.device(device) + Audio2Mel_torch = audio_funcs.Audio2Mel(n_fft=512, hop_length=int(16000/120), win_length=int(16000/60), sampling_rate=16000, + n_mel_channels=80, mel_fmin=90, mel_fmax=7600.0).to(device) + + nframe = int(audio.shape[0] / 16000 * 60) + mel_nframe = 2 * nframe + mel_frame_len = int(sr * winlen) + mel_frame_step = sr * winstep + + mel80s = np.zeros([mel_nframe, 80]) + for i in range(mel_nframe): +# for i in tqdm(range(mel_nframe)): + st = int(i * mel_frame_step) + audio_clip = audio[st : st + mel_frame_len] + if len(audio_clip) < mel_frame_len: + audio_clip = np.concatenate([audio_clip, np.zeros([mel_frame_len - len(audio_clip)])]) + audio_clip_device = torch.from_numpy(audio_clip).unsqueeze(0).unsqueeze(0).to(device).float() + mel80s[i] = Audio2Mel_torch(audio_clip_device).cpu().numpy()[0].T # [1, 80] + + return mel80s + + + +def KNN(feats, feat_database, K=10): + ''' compute KNN for feat in feat base + ''' + tree = KDTree(feat_database, leaf_size=100000) + print('start computing KNN ...') + st = time.time() + dist, ind = tree.query(feats, k=K) + et = time.time() + print('Taken time: ', et-st) + + return dist, ind + + +def KNN_with_torch(feats, feat_database, K=10): + feats = torch.from_numpy(feats)#.cuda() + feat_database = torch.from_numpy(feat_database)#.cuda() + # Training + feat_base_norm = (feat_database ** 2).sum(-1) +# print('start computing KNN ...') +# st = time.time() + feats_norm = (feats ** 2).sum(-1) + diss = (feats_norm.view(-1, 1) + + feat_base_norm.view(1, -1) + - 2 * feats @ feat_database.t() # Rely on cuBLAS for better performance! + ) + ind = diss.topk(K, dim=1, largest=False).indices +# et = time.time() +# print('Taken time: ', et-st) + + return ind.cpu().numpy() + + + + +def solve_LLE_projection(feat, feat_base): + '''find LLE projection weights given feat base and target feat + Args: + feat: [ndim, ] target feat + feat_base: [K, ndim] K-nearest feat base + ======================================= + We need to solve the following function + ``` + min|| feat - \sum_0^k{w_i} * feat_base_i ||, s.t. \sum_0^k{w_i}=1 + ``` + equals to: + ft = w1*f1 + w2*f2 + ... + wk*fk, s.t. w1+w2+...+wk=1 + = (1-w2-...-wk)*f1 + w2*f2 + ... + wk*fk + ft-f1 = w2*(f2-f1) + w3*(f3-f1) + ... + wk*(fk-f1) + ft-f1 = (f2-f1, f3-f1, ..., fk-f1) dot (w2, w3, ..., wk).T + B = A dot w_, here, B: [ndim,] A: [ndim, k-1], w_: [k-1,] + Finally, + ft' = (1-w2-..wk, w2, ..., wk) dot (f1, f2, ..., fk) + ======================================= + Returns: + w: [K,] linear weights, sums to 1 + ft': [ndim,] reconstructed feats + ''' + K, ndim = feat_base.shape + if K == 1: + feat_fuse = feat_base[0] + w = np.array([1]) + else: + w = np.zeros(K) + B = feat - feat_base[0] # [ndim,] + A = (feat_base[1:] - feat_base[0]).T # [ndim, K-1] + AT = A.T + w[1:] = solve(AT.dot(A), AT.dot(B)) + w[0] = 1 - w[1:].sum() + feat_fuse = w.dot(feat_base) + + return w, feat_fuse + + + +def compute_LLE_projection_frame(feats, feat_database, ind): + nframe = feats.shape[0] + feat_fuse = np.zeros_like(feats) + w = np.zeros([nframe, ind.shape[1]]) + current_K_feats = feat_database[ind] + w, feat_fuse = solve_LLE_projection(feats, current_K_feats) + + return w, feat_fuse + + +def compute_LLE_projection_all_frame(feats, feat_database, ind, nframe): + nframe = feats.shape[0] + feat_fuse = np.zeros_like(feats) + w = np.zeros([nframe, ind.shape[1]]) + for i in tqdm(range(nframe), desc='LLE projection'): + current_K_feats = feat_database[ind[i]] + w[i], feat_fuse[i] = solve_LLE_projection(feats[i], current_K_feats) + + return w, feat_fuse + + +def angle2matrix(angles, gradient='false'): + ''' get rotation matrix from three rotation angles(degree). right-handed. + Args: + angles: [3,]. x, y, z angles + x: pitch. positive for looking down. + y: yaw. positive for looking left. + z: roll. positive for tilting head right. + gradient(str): whether to compute gradient matrix: dR/d_x,y,z + Returns: + R: [3, 3]. rotation matrix. + ''' + x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2]) + # x + Rx=np.array([[1, 0, 0], + [0, cos(x), -sin(x)], + [0, sin(x), cos(x)]]) + # y + Ry=np.array([[ cos(y), 0, sin(y)], + [ 0, 1, 0], + [-sin(y), 0, cos(y)]]) + # z + Rz=np.array([[cos(z), -sin(z), 0], + [sin(z), cos(z), 0], + [ 0, 0, 1]]) + + R=Rz.dot(Ry.dot(Rx)) + #R=Rx.dot(Ry.dot(Rz)) + + if gradient != 'true': + return R.astype(np.float32) + elif gradient == 'true': + # gradident matrix + dRxdx = np.array([[0, 0, 0], + [0, -sin(x), -cos(x)], + [0, cos(x), -sin(x)]]) + dRdx = Rz.dot(Ry.dot(dRxdx)) * np.pi/180 + dRydy = np.array([[-sin(y), 0, cos(y)], + [ 0, 0, 0], + [-cos(y), 0, -sin(y)]]) + dRdy = Rz.dot(dRydy.dot(Rx)) * np.pi/180 + dRzdz = np.array([[-sin(z), -cos(z), 0], + [ cos(z), -sin(z), 0], + [ 0, 0, 0]]) + dRdz = dRzdz.dot(Ry.dot(Rx)) * np.pi/180 + + return R.astype(np.float32), [dRdx.astype(np.float32), dRdy.astype(np.float32), dRdz.astype(np.float32)] + + + +def project_landmarks(camera_intrinsic, viewpoint_R, viewpoint_T, scale, headposes, pts_3d): + ''' project 2d landmarks given predicted 3d landmarks & headposes and user-defined + camera & viewpoint parameters + ''' + rot, trans = angle2matrix(headposes[:3]), headposes[3:][:, None] + pts3d_headpose = scale * rot.dot(pts_3d.T) + trans + pts3d_viewpoint = viewpoint_R.dot(pts3d_headpose) + viewpoint_T[:, None] + pts2d_project = camera_intrinsic.dot(pts3d_viewpoint) + pts2d_project[:2, :] /= pts2d_project[2, :] # divide z + pts2d_project = pts2d_project[:2, :].T + + return pts2d_project, rot, trans + + + +def landmark_smooth_3d(pts3d, smooth_sigma=0, area='only_mouth'): + ''' smooth the input 3d landmarks using gaussian filters on each dimension. + Args: + pts3d: [N, 73, 3] + ''' + # per-landmark smooth + if not smooth_sigma == 0: + if area == 'all': + pts3d = gaussian_filter1d(pts3d.reshape(-1, 73*3), smooth_sigma, axis=0).reshape(-1, 73, 3) + elif area == 'only_mouth': + mouth_pts3d = pts3d[:, 46:64, :].copy() + mouth_pts3d = gaussian_filter1d(mouth_pts3d.reshape(-1, 18*3), smooth_sigma, axis=0).reshape(-1, 18, 3) + pts3d = gaussian_filter1d(pts3d.reshape(-1, 73*3), smooth_sigma, axis=0).reshape(-1, 73, 3) + pts3d[:, 46:64, :] = mouth_pts3d + + + + return pts3d + + + +mouth_indices = list(range(46 * 2, 64 * 2)) +upper_outer_lip = list(range(47, 52)) +upper_inner_lip = [63, 62, 61] +lower_inner_lip = [58, 59, 60] +lower_outer_lip = list(range(57, 52, -1)) +lower_mouth = [53, 54, 55, 56, 57, 58, 59, 60] +upper_mouth = [46, 47, 48, 49, 50, 51, 52, 61, 62, 63] +def mouth_pts_AMP(pts3d, is_delta=True, method='XY', paras=[1,1]): + ''' mouth region AMP to control the reaction amplitude. + method: 'XY', 'delta', 'XYZ', 'LowerMore' or 'CloseSmall' + ''' + if method == 'XY': + AMP_scale_x, AMP_scale_y = paras + if is_delta: + pts3d[:, 46:64, 0] *= AMP_scale_x + pts3d[:, 46:64, 1] *= AMP_scale_y + else: + mean_mouth3d_xy = pts3d[:, 46:64, :2].mean(axis=0) + pts3d[:, 46:64, 0] += (AMP_scale_x-1) * (pts3d[:, 46:64, 0] - mean_mouth3d_xy[:,0]) + pts3d[:, 46:64, 1] += (AMP_scale_y-1) * (pts3d[:, 46:64, 1] - mean_mouth3d_xy[:,1]) + elif method == 'delta': + AMP_scale_x, AMP_scale_y = paras + if is_delta: + diff = AMP_scale_x * (pts3d[1:, 46:64] - pts3d[:-1, 46:64]) + pts3d[1:, 46:64] += diff + + elif method == 'XYZ': + AMP_scale_x, AMP_scale_y, AMP_scale_z = paras + if is_delta: + pts3d[:, 46:64, 0] *= AMP_scale_x + pts3d[:, 46:64, 1] *= AMP_scale_y + pts3d[:, 46:64, 2] *= AMP_scale_z + + elif method == 'LowerMore': + upper_x, upper_y, upper_z, lower_x, lower_y, lower_z = paras + if is_delta: + pts3d[:, upper_mouth, 0] *= upper_x + pts3d[:, upper_mouth, 1] *= upper_y + pts3d[:, upper_mouth, 2] *= upper_z + pts3d[:, lower_mouth, 0] *= lower_x + pts3d[:, lower_mouth, 1] *= lower_y + pts3d[:, lower_mouth, 2] *= lower_z + + elif method == 'CloseSmall': + open_x, open_y, open_z, close_x, close_y, close_z = paras + nframe = pts3d.shape[0] + for i in tqdm(range(nframe), desc='AMP mouth..'): + if sum(pts3d[i, upper_mouth, 1] > 0) + sum(pts3d[i, lower_mouth, 1] < 0) > 16 * 0.3: + # open + pts3d[i, 46:64, 0] *= open_x + pts3d[i, 46:64, 1] *= open_y + pts3d[i, 46:64, 2] *= open_z + else: + # close + pts3d[:, 46:64, 0] *= close_x + pts3d[:, 46:64, 1] *= close_y + pts3d[:, 46:64, 2] *= close_z + + return pts3d + + + + +def solve_intersect_mouth(pts3d): + ''' solve the generated intersec lips, usually happens in mouth AMP usage. + Args: + pts3d: [N, 73, 3] + ''' + upper_inner = pts3d[:, upper_inner_lip] + lower_inner = pts3d[:, lower_inner_lip] + + lower_inner_y = lower_inner[:,:,1] + upper_inner_y = upper_inner[:,:,1] + # all three inner lip flip + flip = lower_inner_y > upper_inner_y + flip = np.where(flip.sum(axis=1) == 3)[0] + + # flip frames + inner_y_diff = lower_inner_y[flip] - upper_inner_y[flip] + half_inner_y_diff = inner_y_diff * 0.5 + # upper inner + pts3d[flip[:,None], upper_inner_lip, 1] += half_inner_y_diff + # lower inner + pts3d[flip[:,None], lower_inner_lip, 1] -= half_inner_y_diff + # upper outer + pts3d[flip[:,None], upper_outer_lip, 1] += half_inner_y_diff.mean() + # lower outer + pts3d[flip[:,None], lower_outer_lip, 1] -= half_inner_y_diff.mean() + + + return pts3d + + + +def headpose_smooth(headpose, smooth_sigmas=[0,0], method='gaussian'): + rot_sigma, trans_sigma = smooth_sigmas + rot = gaussian_filter1d(headpose.reshape(-1, 6)[:,:3], rot_sigma, axis=0).reshape(-1, 3) + trans = gaussian_filter1d(headpose.reshape(-1, 6)[:,3:], trans_sigma, axis=0).reshape(-1, 3) + headpose_smooth = np.concatenate([rot, trans], axis=1) + + return headpose_smooth + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/judge_models/FID/fid_eval.py b/LiveSpeechPortraits/source_code/judge_models/FID/fid_eval.py new file mode 100644 index 00000000..74d1d1d2 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/FID/fid_eval.py @@ -0,0 +1,147 @@ +import cv2 +import os +import torch +import shutil +from pytorch_fid import fid_score # 新增引入pytorch-fid + +# 获取视频信息:帧数与持续时间 +def get_video_info(video_path): + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise Exception(f"Cannot open video: {video_path}") + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + duration = frame_count / fps if fps > 0 else 0 + cap.release() + return frame_count, duration + +# 根据目标FPS进行视频帧率调整 +def adjust_fps(input_video, output_video, target_fps, target_frame_count=None): + print(f"Adjusting FPS for {input_video} to {target_fps} FPS...") + cap = cv2.VideoCapture(input_video) + if not cap.isOpened(): + raise Exception(f"Cannot open video: {input_video}") + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + original_fps = cap.get(cv2.CAP_PROP_FPS) + fourcc = cv2.VideoWriter_fourcc(*'XVID') + + out = cv2.VideoWriter(output_video, fourcc, target_fps, (width, height)) + frame_interval = int(original_fps / target_fps) if target_fps < original_fps else 1 + + frame_count = 0 + extracted_frame_count = 0 + last_frame = None + + while True: + ret, frame = cap.read() + if not ret: + print(f"Warning: Failed to read frame {frame_count}. Using last valid frame.") + if last_frame is not None and (target_frame_count is None or extracted_frame_count < target_frame_count): + out.write(last_frame) + extracted_frame_count += 1 + if target_frame_count is not None and extracted_frame_count >= target_frame_count: + break + frame_count += 1 + continue + + if frame_count % frame_interval == 0: + out.write(frame) + extracted_frame_count += 1 + last_frame = frame + + if target_frame_count is not None and extracted_frame_count >= target_frame_count: + break + + frame_count += 1 + + # 如果需要,使用最后一帧填补 + while target_frame_count is not None and extracted_frame_count < target_frame_count: + if last_frame is not None: + out.write(last_frame) + extracted_frame_count += 1 + else: + print("Error: No valid frames available to pad the output.") + + cap.release() + out.release() + print(f"FPS adjustment completed: {input_video} -> {output_video}. Extracted {extracted_frame_count} frames.") + return extracted_frame_count + +# 将视频转为帧 +def video_to_frames(video_path, output_dir, target_frame_count=None): + if os.path.exists(output_dir) and os.listdir(output_dir): + print(f"Frames already exist in {output_dir}, skipping extraction.") + return len(os.listdir(output_dir)) + + print(f"Extracting frames from {video_path}...") + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise Exception(f"Cannot open video: {video_path}") + os.makedirs(output_dir, exist_ok=True) + frame_idx = 0 + extracted_frame_count = 0 + while True: + ret, frame = cap.read() + if not ret: + break + frame_path = os.path.join(output_dir, f"frame_{frame_idx:04d}.png") + cv2.imwrite(frame_path, frame) + extracted_frame_count += 1 + frame_idx += 1 + if target_frame_count is not None and extracted_frame_count >= target_frame_count: + break + cap.release() + print(f"Frame extraction completed. Total frames: {extracted_frame_count}") + return extracted_frame_count + +# 这里原本有InceptionV3FeatureExtractor和preprocess_image函数,现在不再需要。 + +# 使用pytorch-fid计算FID +def compute_fid(gt_frames_dir, gen_frames_dir): + # 使用pytorch_fid提供的calculate_fid_given_paths计算FID + fid = fid_score.calculate_fid_given_paths( + [gt_frames_dir, gen_frames_dir], + batch_size=50, # 可根据情况调整 + device='cuda:0' if torch.cuda.is_available() else 'cpu', + dims=2048 # 与Inception V3特征维度一致 + ) + return fid + +# 主流程,只保留FID计算 +def compute_fid_for_videos(gt_video, gen_video, output_dir="./fid_evaluation_output", target_fps=30): + os.makedirs(output_dir, exist_ok=True) + # 获取视频信息 + gt_frame_count, gt_duration = get_video_info(gt_video) + gen_frame_count, gen_duration = get_video_info(gen_video) + + target_frame_count = min(gt_frame_count, gen_frame_count) + target_fps_gt = target_frame_count / gt_duration if gt_duration > 0 else target_fps + target_fps_gen = target_frame_count / gen_duration if gen_duration > 0 else target_fps + + # 调整视频帧数和FPS + adjusted_gt_video = os.path.join(output_dir, "adjusted_gt_video.avi") + gt_frame_count = adjust_fps(gt_video, adjusted_gt_video, target_fps_gt, target_frame_count) + + adjusted_gen_video = os.path.join(output_dir, "adjusted_gen_video.avi") + gen_frame_count = adjust_fps(gen_video, adjusted_gen_video, target_fps_gen, target_frame_count) + + # 提取帧 + gt_frames_dir = os.path.join(output_dir, "ground_truth") + gt_frame_count = video_to_frames(adjusted_gt_video, gt_frames_dir, target_frame_count) + + gen_frames_dir = os.path.join(output_dir, "generated") + gen_frame_count = video_to_frames(adjusted_gen_video, gen_frames_dir, target_frame_count) + + # 使用pytorch_fid计算FID + fid = compute_fid(gt_frames_dir, gen_frames_dir) + shutil.rmtree(output_dir) + return fid + +if __name__ == "__main__": + gt_video_path = "E:\\Code\\DesktopCode\\LiveSpeechPortraits\\data\\Input\\May_short.mp4" + gen_video_path = "E:\\Code\\DesktopCode\\LiveSpeechPortraits\\results\\May\\May_short\\May_short.avi" + + fid_value = compute_fid_for_videos(gt_video_path, gen_video_path) + print(f"FID: {fid_value}") diff --git a/LiveSpeechPortraits/source_code/judge_models/L1_PSNR_SSIM_LPIPS/eval.py b/LiveSpeechPortraits/source_code/judge_models/L1_PSNR_SSIM_LPIPS/eval.py new file mode 100644 index 00000000..653e4488 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/L1_PSNR_SSIM_LPIPS/eval.py @@ -0,0 +1,181 @@ +import cv2 +import os +import numpy as np +import torch +import lpips +from tqdm import tqdm # Progress bar +import math +from torchmetrics.functional import structural_similarity_index_measure as torch_ssim + +# Helper function: get total frame count and duration of a video +def get_video_info(video_path): + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise Exception(f"Cannot open video: {video_path}") + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + duration = frame_count / fps if fps > 0 else 0 + cap.release() + return frame_count, duration + +# Helper function: get frame indices to extract based on target FPS +def get_frame_indices(original_fps, target_fps, total_frames): + if target_fps >= original_fps: + return list(range(total_frames)) + frame_interval = original_fps / target_fps + indices = [int(i * frame_interval) for i in range(int(target_fps * (total_frames / original_fps)))] + return indices + +# Helper function: compute L1 Loss +def compute_l1_loss(gt_frame, gen_frame): + return np.mean(np.abs(gt_frame - gen_frame)) + +# Helper function: compute PSNR +def compute_psnr(gt_frame, gen_frame): + mse = np.mean((gt_frame - gen_frame) ** 2) + return 20 * np.log10(255.0 / math.sqrt(mse)) if mse > 0 else float('inf') + +# Helper function: compute SSIM using torchmetrics +def compute_ssim_metric(gt_tensor, gen_tensor, ssim_fn): + ssim_value = ssim_fn(gt_tensor, gen_tensor) + return ssim_value.item() + +# Helper function: compute LPIPS +def compute_lpips_metric(gt_tensor, gen_tensor, loss_fn): + return loss_fn(gt_tensor, gen_tensor).item() + +# Helper function: compute all metrics for a frame pair +def compute_metrics(gt_frame, gen_frame, loss_fn, ssim_fn, device): + metrics = {} + + # L1 Loss + metrics["L1"] = compute_l1_loss(gt_frame, gen_frame) + + # PSNR + metrics["PSNR"] = compute_psnr(gt_frame, gen_frame) + + # Convert frames to tensors + gt_tensor = torch.tensor(gt_frame, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device) / 255.0 + gen_tensor = torch.tensor(gen_frame, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device) / 255.0 + + # SSIM + metrics["SSIM"] = compute_ssim_metric(gt_tensor, gen_tensor, ssim_fn) + + # LPIPS + metrics["LPIPS"] = compute_lpips_metric(gt_tensor, gen_tensor, loss_fn) + + return metrics + +def compute_all_metrics(gt_video, gen_video, target_fps=30, debug=False): + # Step 1: Get total frame count and duration of videos + print(f"Loading ground truth video: {gt_video}") + print(f"Loading generated video: {gen_video}") + gt_frame_count, gt_duration = get_video_info(gt_video) + gen_frame_count, gen_duration = get_video_info(gen_video) + + print(f"Ground truth video: {gt_frame_count} frames, {gt_duration:.2f} seconds") + print(f"Generated video: {gen_frame_count} frames, {gen_duration:.2f} seconds") + + # Step 2: Determine target frame count based on FPS and duration + target_frame_count = min(int(target_fps * gt_duration), int(target_fps * gen_duration)) + print(f"Target frame count based on {target_fps} FPS: {target_frame_count}") + + # Step 3: Calculate frame indices to extract for both videos + cap_gt = cv2.VideoCapture(gt_video) + cap_gen = cv2.VideoCapture(gen_video) + + original_fps_gt = cap_gt.get(cv2.CAP_PROP_FPS) + original_fps_gen = cap_gen.get(cv2.CAP_PROP_FPS) + + frame_indices_gt = get_frame_indices(original_fps_gt, target_fps, gt_frame_count) + frame_indices_gen = get_frame_indices(original_fps_gen, target_fps, gen_frame_count) + + # Ensure both frame lists have the same number of frames + min_frames = min(len(frame_indices_gt), len(frame_indices_gen), target_frame_count) + frame_indices_gt = frame_indices_gt[:min_frames] + frame_indices_gen = frame_indices_gen[:min_frames] + + print(f"Number of frames to process: {min_frames}") + + # Step 4: Initialize LPIPS and SSIM once + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + loss_fn = lpips.LPIPS(net='alex').to(device) + ssim_fn = torch_ssim + + # Step 5: Read and process frames + metrics_list = [] + + # Preload frame indices into sets for faster access + frame_set_gt = set(frame_indices_gt) + frame_set_gen = set(frame_indices_gen) + + current_frame_gt = 0 + current_frame_gen = 0 + target_idx = 0 + + frames_gt = [] + print("Reading ground truth frames into memory...") + for desired_frame in tqdm(frame_indices_gt, desc="Loading GT frames"): + while current_frame_gt <= desired_frame: + ret, frame = cap_gt.read() + if not ret: + print(f"Warning: Reached end of ground truth video at frame {current_frame_gt}.") + frames_gt.append(None) + break + if current_frame_gt == desired_frame: + frames_gt.append(frame) + break + current_frame_gt += 1 + + # Reset for generated video + target_idx = 0 + current_frame_gen = 0 + frames_gen = [] + print("Reading generated video frames into memory...") + for desired_frame in tqdm(frame_indices_gen, desc="Loading Gen frames"): + while current_frame_gen <= desired_frame: + ret, frame = cap_gen.read() + if not ret: + # print(f"Warning: Reached end of generated video at frame {current_frame_gen}.") + frames_gen.append(None) + break + if current_frame_gen == desired_frame: + frames_gen.append(frame) + break + current_frame_gen += 1 + + cap_gt.release() + cap_gen.release() + + print(f"Total frames loaded: {len(frames_gt)}") + + # Step 6: Compute metrics for each frame pair + for i in tqdm(range(len(frames_gt)), desc="Computing metrics"): + gt_frame = frames_gt[i] + gen_frame = frames_gen[i] + + if gt_frame is None or gen_frame is None: + # print(f"Skipping frame {i} due to read error.") + continue + + metrics = compute_metrics(gt_frame, gen_frame, loss_fn, ssim_fn, device) + metrics_list.append(metrics) + + # Step 7: Calculate and print average metrics + metrics_array = np.array([list(m.values()) for m in metrics_list]) + metrics_names = list(metrics_list[0].keys()) + average_metrics = np.mean(metrics_array, axis=0) + print("\n=== Average Metrics ===") + for name, value in zip(metrics_names, average_metrics): + print(f"{name}: {value:.4f}") + + return average_metrics + +# Example usage +if __name__ == "__main__": + output_dir = "output_metrics" + gt_video = "E:\\Code\\DesktopCode\\LiveSpeechPortraits\\data\\Input\\May_short.mp4" + gen_video = "E:\\Code\\DesktopCode\\LiveSpeechPortraits\\results\\May\\May_short\\May_short.avi" + lists = compute_all_metrics(gt_video, gen_video) + print(lists) diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/.gitignore b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/.gitignore new file mode 100644 index 00000000..350ada00 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/.gitignore @@ -0,0 +1,45 @@ +# Compiled source # +################### +*.com +*.class +*.dll +*.exe +*.o +*.so +*.pyc + +# Packages # +############ +# it's better to unpack these files and commit the raw source +# git has its own built in compression methods +*.7z +*.dmg +*.gz +*.iso +*.jar +*.rar +*.tar +*.zip + +# Logs and databases # +###################### +*.log +*.sql +*.sqlite + +# OS generated files # +###################### +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Specific to this demo # +######################### +data/ +protos/ +utils/ +*.pth diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/LICENSE.md b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/LICENSE.md new file mode 100644 index 00000000..de4a5458 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/LICENSE.md @@ -0,0 +1,19 @@ +Copyright (c) 2016-present Joon Son Chung. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/README.md b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/README.md new file mode 100644 index 00000000..7da53541 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/README.md @@ -0,0 +1,59 @@ +# SyncNet + +This repository contains the demo for the audio-to-video synchronisation network (SyncNet). This network can be used for audio-visual synchronisation tasks including: +1. Removing temporal lags between the audio and visual streams in a video; +2. Determining who is speaking amongst multiple faces in a video. + +Please cite the paper below if you make use of the software. + +## Dependencies +``` +pip install -r requirements.txt +``` + +In addition, `ffmpeg` is required. + + +## Demo + +SyncNet demo: +``` +python demo_syncnet.py --videofile data/example.avi --tmp_dir /path/to/temp/directory +``` + +Check that this script returns: +``` +AV offset: 3 +Min dist: 5.353 +Confidence: 10.021 +``` + +Full pipeline: +``` +sh download_model.sh +python run_pipeline.py --videofile /path/to/video.mp4 --reference name_of_video --data_dir /path/to/output +python run_syncnet.py --videofile /path/to/video.mp4 --reference name_of_video --data_dir /path/to/output +python run_visualise.py --videofile /path/to/video.mp4 --reference name_of_video --data_dir /path/to/output +``` + +Outputs: +``` +$DATA_DIR/pycrop/$REFERENCE/*.avi - cropped face tracks +$DATA_DIR/pywork/$REFERENCE/offsets.txt - audio-video offset values +$DATA_DIR/pyavi/$REFERENCE/video_out.avi - output video (as shown below) +``` +

+ + +

+ +## Publications + +``` +@InProceedings{Chung16a, + author = "Chung, J.~S. and Zisserman, A.", + title = "Out of time: automated lip sync in the wild", + booktitle = "Workshop on Multi-view Lip-reading, ACCV", + year = "2016", +} +``` diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/SyncNetInstance.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/SyncNetInstance.py new file mode 100644 index 00000000..497d44fc --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/SyncNetInstance.py @@ -0,0 +1,208 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- +# Video 25 FPS, Audio 16000HZ + +import torch +import numpy +import time, pdb, argparse, subprocess, os, math, glob +import cv2 +import python_speech_features + +from scipy import signal +from scipy.io import wavfile +from SyncNetModel import * +from shutil import rmtree + + +# ==================== Get OFFSET ==================== + +def calc_pdist(feat1, feat2, vshift=10): + + win_size = vshift*2+1 + + feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift)) + + dists = [] + + for i in range(0,len(feat1)): + + dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:])) + + return dists + +# ==================== MAIN DEF ==================== + +class SyncNetInstance(torch.nn.Module): + + def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024): + super(SyncNetInstance, self).__init__(); + + self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda(); + + def evaluate(self, opt, videofile): + + self.__S__.eval(); + + # ========== ========== + # Convert files + # ========== ========== + + if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)): + rmtree(os.path.join(opt.tmp_dir,opt.reference)) + + os.makedirs(os.path.join(opt.tmp_dir,opt.reference)) + + command = ("ffmpeg -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg'))) + output = subprocess.call(command, shell=True, stdout=None) + + command = ("ffmpeg -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav'))) + output = subprocess.call(command, shell=True, stdout=None) + + # ========== ========== + # Load video + # ========== ========== + + images = [] + + flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg')) + flist.sort() + + for fname in flist: + images.append(cv2.imread(fname)) + + im = numpy.stack(images,axis=3) + im = numpy.expand_dims(im,axis=0) + im = numpy.transpose(im,(0,3,4,1,2)) + + imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) + + # ========== ========== + # Load audio + # ========== ========== + + sample_rate, audio = wavfile.read(os.path.join(opt.tmp_dir,opt.reference,'audio.wav')) + mfcc = zip(*python_speech_features.mfcc(audio,sample_rate)) + mfcc = numpy.stack([numpy.array(i) for i in mfcc]) + + cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0) + cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float()) + + # ========== ========== + # Check audio and video input length + # ========== ========== + + if (float(len(audio))/16000) != (float(len(images))/25) : + print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25)) + + min_length = min(len(images),math.floor(len(audio)/640)) + + # ========== ========== + # Generate video and audio feats + # ========== ========== + + lastframe = min_length-5 + im_feat = [] + cc_feat = [] + + tS = time.time() + for i in range(0,lastframe,opt.batch_size): + + im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + im_in = torch.cat(im_batch,0) + im_out = self.__S__.forward_lip(im_in.cuda()); + im_feat.append(im_out.data.cpu()) + + cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + cc_in = torch.cat(cc_batch,0) + cc_out = self.__S__.forward_aud(cc_in.cuda()) + cc_feat.append(cc_out.data.cpu()) + + im_feat = torch.cat(im_feat,0) + cc_feat = torch.cat(cc_feat,0) + + # ========== ========== + # Compute offset + # ========== ========== + + print('Compute time %.3f sec.' % (time.time()-tS)) + + dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift) + mdist = torch.mean(torch.stack(dists,1),1) + + minval, minidx = torch.min(mdist,0) + + offset = opt.vshift-minidx + conf = torch.median(mdist) - minval + + fdist = numpy.stack([dist[minidx].numpy() for dist in dists]) + # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15) + fconf = torch.median(mdist).numpy() - fdist + fconfm = signal.medfilt(fconf,kernel_size=9) + + numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format}) + print('Framewise conf: ') + print(fconfm) + print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf)) + + dists_npy = numpy.array([ dist.numpy() for dist in dists ]) + return offset.numpy(), conf.numpy(), dists_npy + + def extract_feature(self, opt, videofile): + + self.__S__.eval(); + + # ========== ========== + # Load video + # ========== ========== + cap = cv2.VideoCapture(videofile) + + frame_num = 1; + images = [] + while frame_num: + frame_num += 1 + ret, image = cap.read() + if ret == 0: + break + + images.append(image) + + im = numpy.stack(images,axis=3) + im = numpy.expand_dims(im,axis=0) + im = numpy.transpose(im,(0,3,4,1,2)) + + imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) + + # ========== ========== + # Generate video feats + # ========== ========== + + lastframe = len(images)-4 + im_feat = [] + + tS = time.time() + for i in range(0,lastframe,opt.batch_size): + + im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + im_in = torch.cat(im_batch,0) + im_out = self.__S__.forward_lipfeat(im_in.cuda()); + im_feat.append(im_out.data.cpu()) + + im_feat = torch.cat(im_feat,0) + + # ========== ========== + # Compute offset + # ========== ========== + + print('Compute time %.3f sec.' % (time.time()-tS)) + + return im_feat + + + def loadParameters(self, path): + loaded_state = torch.load(path, map_location=lambda storage, loc: storage); + + self_state = self.__S__.state_dict(); + + for name, param in loaded_state.items(): + + self_state[name].copy_(param); diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/SyncNetInstance_calc_scores.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/SyncNetInstance_calc_scores.py new file mode 100644 index 00000000..64906e25 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/SyncNetInstance_calc_scores.py @@ -0,0 +1,210 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- +# Video 25 FPS, Audio 16000HZ + +import torch +import numpy +import time, pdb, argparse, subprocess, os, math, glob +import cv2 +import python_speech_features + +from scipy import signal +from scipy.io import wavfile +from SyncNetModel import * +from shutil import rmtree + + +# ==================== Get OFFSET ==================== + +def calc_pdist(feat1, feat2, vshift=10): + + win_size = vshift*2+1 + + feat2p = torch.nn.functional.pad(feat2,(0,0,vshift,vshift)) + + dists = [] + + for i in range(0,len(feat1)): + + dists.append(torch.nn.functional.pairwise_distance(feat1[[i],:].repeat(win_size, 1), feat2p[i:i+win_size,:])) + + return dists + +# ==================== MAIN DEF ==================== + +class SyncNetInstance(torch.nn.Module): + + def __init__(self, dropout = 0, num_layers_in_fc_layers = 1024): + super(SyncNetInstance, self).__init__(); + + self.__S__ = S(num_layers_in_fc_layers = num_layers_in_fc_layers).cuda(); + + def evaluate(self, opt, videofile): + + self.__S__.eval(); + + # ========== ========== + # Convert files + # ========== ========== + + if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)): + rmtree(os.path.join(opt.tmp_dir,opt.reference)) + + os.makedirs(os.path.join(opt.tmp_dir,opt.reference)) + + command = ("ffmpeg -loglevel error -y -i %s -threads 1 -f image2 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'%06d.jpg'))) + output = subprocess.call(command, shell=True, stdout=None) + + command = ("ffmpeg -loglevel error -y -i %s -async 1 -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (videofile,os.path.join(opt.tmp_dir,opt.reference,'audio.wav'))) + output = subprocess.call(command, shell=True, stdout=None) + + # ========== ========== + # Load video + # ========== ========== + + images = [] + + flist = glob.glob(os.path.join(opt.tmp_dir,opt.reference,'*.jpg')) + flist.sort() + + for fname in flist: + img_input = cv2.imread(fname) + img_input = cv2.resize(img_input, (224,224)) #HARD CODED, CHANGE BEFORE RELEASE + images.append(img_input) + + im = numpy.stack(images,axis=3) + im = numpy.expand_dims(im,axis=0) + im = numpy.transpose(im,(0,3,4,1,2)) + + imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) + + # ========== ========== + # Load audio + # ========== ========== + + sample_rate, audio = wavfile.read(os.path.join(opt.tmp_dir,opt.reference,'audio.wav')) + mfcc = zip(*python_speech_features.mfcc(audio,sample_rate)) + mfcc = numpy.stack([numpy.array(i) for i in mfcc]) + + cc = numpy.expand_dims(numpy.expand_dims(mfcc,axis=0),axis=0) + cct = torch.autograd.Variable(torch.from_numpy(cc.astype(float)).float()) + + # ========== ========== + # Check audio and video input length + # ========== ========== + + #if (float(len(audio))/16000) != (float(len(images))/25) : + # print("WARNING: Audio (%.4fs) and video (%.4fs) lengths are different."%(float(len(audio))/16000,float(len(images))/25)) + + min_length = min(len(images),math.floor(len(audio)/640)) + + # ========== ========== + # Generate video and audio feats + # ========== ========== + + lastframe = min_length-5 + im_feat = [] + cc_feat = [] + + tS = time.time() + for i in range(0,lastframe,opt.batch_size): + + im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + im_in = torch.cat(im_batch,0) + im_out = self.__S__.forward_lip(im_in.cuda()); + im_feat.append(im_out.data.cpu()) + + cc_batch = [ cct[:,:,:,vframe*4:vframe*4+20] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + cc_in = torch.cat(cc_batch,0) + cc_out = self.__S__.forward_aud(cc_in.cuda()) + cc_feat.append(cc_out.data.cpu()) + + im_feat = torch.cat(im_feat,0) + cc_feat = torch.cat(cc_feat,0) + + # ========== ========== + # Compute offset + # ========== ========== + + #print('Compute time %.3f sec.' % (time.time()-tS)) + + dists = calc_pdist(im_feat,cc_feat,vshift=opt.vshift) + mdist = torch.mean(torch.stack(dists,1),1) + + minval, minidx = torch.min(mdist,0) + + offset = opt.vshift-minidx + conf = torch.median(mdist) - minval + + fdist = numpy.stack([dist[minidx].numpy() for dist in dists]) + # fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=15) + fconf = torch.median(mdist).numpy() - fdist + fconfm = signal.medfilt(fconf,kernel_size=9) + + numpy.set_printoptions(formatter={'float': '{: 0.3f}'.format}) + #print('Framewise conf: ') + #print(fconfm) + #print('AV offset: \t%d \nMin dist: \t%.3f\nConfidence: \t%.3f' % (offset,minval,conf)) + + dists_npy = numpy.array([ dist.numpy() for dist in dists ]) + return offset.numpy(), conf.numpy(), minval.numpy() + + def extract_feature(self, opt, videofile): + + self.__S__.eval(); + + # ========== ========== + # Load video + # ========== ========== + cap = cv2.VideoCapture(videofile) + + frame_num = 1; + images = [] + while frame_num: + frame_num += 1 + ret, image = cap.read() + if ret == 0: + break + + images.append(image) + + im = numpy.stack(images,axis=3) + im = numpy.expand_dims(im,axis=0) + im = numpy.transpose(im,(0,3,4,1,2)) + + imtv = torch.autograd.Variable(torch.from_numpy(im.astype(float)).float()) + + # ========== ========== + # Generate video feats + # ========== ========== + + lastframe = len(images)-4 + im_feat = [] + + tS = time.time() + for i in range(0,lastframe,opt.batch_size): + + im_batch = [ imtv[:,:,vframe:vframe+5,:,:] for vframe in range(i,min(lastframe,i+opt.batch_size)) ] + im_in = torch.cat(im_batch,0) + im_out = self.__S__.forward_lipfeat(im_in.cuda()); + im_feat.append(im_out.data.cpu()) + + im_feat = torch.cat(im_feat,0) + + # ========== ========== + # Compute offset + # ========== ========== + + print('Compute time %.3f sec.' % (time.time()-tS)) + + return im_feat + + + def loadParameters(self, path): + loaded_state = torch.load(path, map_location=lambda storage, loc: storage); + + self_state = self.__S__.state_dict(); + + for name, param in loaded_state.items(): + + self_state[name].copy_(param); diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/SyncNetModel.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/SyncNetModel.py new file mode 100644 index 00000000..c21ce25c --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/SyncNetModel.py @@ -0,0 +1,117 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import torch +import torch.nn as nn + +def save(model, filename): + with open(filename, "wb") as f: + torch.save(model, f); + print("%s saved."%filename); + +def load(filename): + net = torch.load(filename) + return net; + +class S(nn.Module): + def __init__(self, num_layers_in_fc_layers = 1024): + super(S, self).__init__(); + + self.__nFeatures__ = 24; + self.__nChs__ = 32; + self.__midChs__ = 32; + + self.netcnnaud = nn.Sequential( + nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1)), + nn.BatchNorm2d(64), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(1,1), stride=(1,1)), + + nn.Conv2d(64, 192, kernel_size=(3,3), stride=(1,1), padding=(1,1)), + nn.BatchNorm2d(192), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(3,3), stride=(1,2)), + + nn.Conv2d(192, 384, kernel_size=(3,3), padding=(1,1)), + nn.BatchNorm2d(384), + nn.ReLU(inplace=True), + + nn.Conv2d(384, 256, kernel_size=(3,3), padding=(1,1)), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + + nn.Conv2d(256, 256, kernel_size=(3,3), padding=(1,1)), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(3,3), stride=(2,2)), + + nn.Conv2d(256, 512, kernel_size=(5,4), padding=(0,0)), + nn.BatchNorm2d(512), + nn.ReLU(), + ); + + self.netfcaud = nn.Sequential( + nn.Linear(512, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, num_layers_in_fc_layers), + ); + + self.netfclip = nn.Sequential( + nn.Linear(512, 512), + nn.BatchNorm1d(512), + nn.ReLU(), + nn.Linear(512, num_layers_in_fc_layers), + ); + + self.netcnnlip = nn.Sequential( + nn.Conv3d(3, 96, kernel_size=(5,7,7), stride=(1,2,2), padding=0), + nn.BatchNorm3d(96), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)), + + nn.Conv3d(96, 256, kernel_size=(1,5,5), stride=(1,2,2), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2), padding=(0,1,1)), + + nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + + nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + + nn.Conv3d(256, 256, kernel_size=(1,3,3), padding=(0,1,1)), + nn.BatchNorm3d(256), + nn.ReLU(inplace=True), + nn.MaxPool3d(kernel_size=(1,3,3), stride=(1,2,2)), + + nn.Conv3d(256, 512, kernel_size=(1,6,6), padding=0), + nn.BatchNorm3d(512), + nn.ReLU(inplace=True), + ); + + def forward_aud(self, x): + + mid = self.netcnnaud(x); # N x ch x 24 x M + mid = mid.view((mid.size()[0], -1)); # N x (ch x 24) + out = self.netfcaud(mid); + + return out; + + def forward_lip(self, x): + + mid = self.netcnnlip(x); + mid = mid.view((mid.size()[0], -1)); # N x (ch x 24) + out = self.netfclip(mid); + + return out; + + def forward_lipfeat(self, x): + + mid = self.netcnnlip(x); + out = mid.view((mid.size()[0], -1)); # N x (ch x 24) + + return out; \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/all_scores.txt b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/all_scores.txt new file mode 100644 index 00000000..31d7a474 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/all_scores.txt @@ -0,0 +1 @@ +6.296045 9.87826 diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/calculate_scores_LRS.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/calculate_scores_LRS.py new file mode 100644 index 00000000..eda02b8f --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/calculate_scores_LRS.py @@ -0,0 +1,53 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import time, pdb, argparse, subprocess +import glob +import os +from tqdm import tqdm + +from SyncNetInstance_calc_scores import * + +# ==================== LOAD PARAMS ==================== + + +parser = argparse.ArgumentParser(description = "SyncNet"); + +parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help=''); +parser.add_argument('--batch_size', type=int, default='20', help=''); +parser.add_argument('--vshift', type=int, default='15', help=''); +parser.add_argument('--data_root', type=str, required=True, help=''); +parser.add_argument('--tmp_dir', type=str, default="data/work/pytmp", help=''); +parser.add_argument('--reference', type=str, default="demo", help=''); + +opt = parser.parse_args(); + + +# ==================== RUN EVALUATION ==================== + +s = SyncNetInstance(); + +s.loadParameters(opt.initial_model); +#print("Model %s loaded."%opt.initial_model); +path = os.path.join(opt.data_root, "*.mp4") + +all_videos = glob.glob(path) + +prog_bar = tqdm(range(len(all_videos))) +avg_confidence = 0. +avg_min_distance = 0. + + +for videofile_idx in prog_bar: + videofile = all_videos[videofile_idx] + offset, confidence, min_distance = s.evaluate(opt, videofile=videofile) + avg_confidence += confidence + avg_min_distance += min_distance + prog_bar.set_description('Avg Confidence: {}, Avg Minimum Dist: {}'.format(round(avg_confidence / (videofile_idx + 1), 3), round(avg_min_distance / (videofile_idx + 1), 3))) + prog_bar.refresh() + +print ('Average Confidence: {}'.format(avg_confidence/len(all_videos))) +print ('Average Minimum Distance: {}'.format(avg_min_distance/len(all_videos))) + + + diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/calculate_scores_real_videos.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/calculate_scores_real_videos.py new file mode 100644 index 00000000..09622584 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/calculate_scores_real_videos.py @@ -0,0 +1,45 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import time, pdb, argparse, subprocess, pickle, os, gzip, glob + +from SyncNetInstance_calc_scores import * + +# ==================== PARSE ARGUMENT ==================== + +parser = argparse.ArgumentParser(description = "SyncNet"); +parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help=''); +parser.add_argument('--batch_size', type=int, default='20', help=''); +parser.add_argument('--vshift', type=int, default='15', help=''); +parser.add_argument('--data_dir', type=str, default='data/work', help=''); +parser.add_argument('--videofile', type=str, default='', help=''); +parser.add_argument('--reference', type=str, default='', help=''); +opt = parser.parse_args(); + +setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi')) +setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp')) +setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork')) +setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop')) + + +# ==================== LOAD MODEL AND FILE LIST ==================== + +s = SyncNetInstance(); + +s.loadParameters(opt.initial_model); +#print("Model %s loaded."%opt.initial_model); + +flist = glob.glob(os.path.join(opt.crop_dir,opt.reference,'0*.avi')) +flist.sort() + +# ==================== GET OFFSETS ==================== + +dists = [] +for idx, fname in enumerate(flist): + offset, conf, dist = s.evaluate(opt,videofile=fname) + print (str(dist)+" "+str(conf)) + +# ==================== PRINT RESULTS TO FILE ==================== + +#with open(os.path.join(opt.work_dir,opt.reference,'activesd.pckl'), 'wb') as fil: +# pickle.dump(dists, fil) diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/demo_feature.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/demo_feature.py new file mode 100644 index 00000000..e3bd290e --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/demo_feature.py @@ -0,0 +1,32 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import time, pdb, argparse, subprocess + +from SyncNetInstance import * + +# ==================== LOAD PARAMS ==================== + + +parser = argparse.ArgumentParser(description = "SyncNet"); + +parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help=''); +parser.add_argument('--batch_size', type=int, default='20', help=''); +parser.add_argument('--vshift', type=int, default='15', help=''); +parser.add_argument('--videofile', type=str, default="data/example.avi", help=''); +parser.add_argument('--tmp_dir', type=str, default="data", help=''); +parser.add_argument('--save_as', type=str, default="data/features.pt", help=''); + +opt = parser.parse_args(); + + +# ==================== RUN EVALUATION ==================== + +s = SyncNetInstance(); + +s.loadParameters(opt.initial_model); +print("Model %s loaded."%opt.initial_model); + +feats = s.extract_feature(opt, videofile=opt.videofile) + +torch.save(feats, opt.save_as) diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/demo_syncnet.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/demo_syncnet.py new file mode 100644 index 00000000..01c25a6f --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/demo_syncnet.py @@ -0,0 +1,30 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import time, pdb, argparse, subprocess + +from SyncNetInstance import * + +# ==================== LOAD PARAMS ==================== + + +parser = argparse.ArgumentParser(description = "SyncNet"); + +parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help=''); +parser.add_argument('--batch_size', type=int, default='20', help=''); +parser.add_argument('--vshift', type=int, default='15', help=''); +parser.add_argument('--videofile', type=str, default="data/example.avi", help=''); +parser.add_argument('--tmp_dir', type=str, default="data/work/pytmp", help=''); +parser.add_argument('--reference', type=str, default="demo", help=''); + +opt = parser.parse_args(); + + +# ==================== RUN EVALUATION ==================== + +s = SyncNetInstance(); + +s.loadParameters(opt.initial_model); +print("Model %s loaded."%opt.initial_model); + +s.evaluate(opt, videofile=opt.videofile) diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/README.md b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/README.md new file mode 100644 index 00000000..f5a8d4fe --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/README.md @@ -0,0 +1,3 @@ +# Face detector + +This face detector is adapted from `https://github.com/cs-giung/face-detection-pytorch`. diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/__init__.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/__init__.py new file mode 100644 index 00000000..059d49bf --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/__init__.py @@ -0,0 +1 @@ +from .s3fd import S3FD \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/s3fd/__init__.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/s3fd/__init__.py new file mode 100644 index 00000000..d7f35e05 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/s3fd/__init__.py @@ -0,0 +1,61 @@ +import time +import numpy as np +import cv2 +import torch +from torchvision import transforms +from .nets import S3FDNet +from .box_utils import nms_ + +PATH_WEIGHT = './detectors/s3fd/weights/sfd_face.pth' +img_mean = np.array([104., 117., 123.])[:, np.newaxis, np.newaxis].astype('float32') + + +class S3FD(): + + def __init__(self, device='cuda'): + + tstamp = time.time() + self.device = device + + print('[S3FD] loading with', self.device) + self.net = S3FDNet(device=self.device).to(self.device) + state_dict = torch.load(PATH_WEIGHT, map_location=self.device) + self.net.load_state_dict(state_dict) + self.net.eval() + print('[S3FD] finished loading (%.4f sec)' % (time.time() - tstamp)) + + def detect_faces(self, image, conf_th=0.8, scales=[1]): + + w, h = image.shape[1], image.shape[0] + + bboxes = np.empty(shape=(0, 5)) + + with torch.no_grad(): + for s in scales: + scaled_img = cv2.resize(image, dsize=(0, 0), fx=s, fy=s, interpolation=cv2.INTER_LINEAR) + + scaled_img = np.swapaxes(scaled_img, 1, 2) + scaled_img = np.swapaxes(scaled_img, 1, 0) + scaled_img = scaled_img[[2, 1, 0], :, :] + scaled_img = scaled_img.astype('float32') + scaled_img -= img_mean + scaled_img = scaled_img[[2, 1, 0], :, :] + x = torch.from_numpy(scaled_img).unsqueeze(0).to(self.device) + y = self.net(x) + + detections = y.data + scale = torch.Tensor([w, h, w, h]) + + for i in range(detections.size(1)): + j = 0 + while detections[0, i, j, 0] > conf_th: + score = detections[0, i, j, 0] + pt = (detections[0, i, j, 1:] * scale).cpu().numpy() + bbox = (pt[0], pt[1], pt[2], pt[3], score) + bboxes = np.vstack((bboxes, bbox)) + j += 1 + + keep = nms_(bboxes, 0.1) + bboxes = bboxes[keep] + + return bboxes diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/s3fd/box_utils.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/s3fd/box_utils.py new file mode 100644 index 00000000..0779bcd5 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/s3fd/box_utils.py @@ -0,0 +1,217 @@ +import numpy as np +from itertools import product as product +import torch +from torch.autograd import Function + + +def nms_(dets, thresh): + """ + Courtesy of Ross Girshick + [https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py] + """ + x1 = dets[:, 0] + y1 = dets[:, 1] + x2 = dets[:, 2] + y2 = dets[:, 3] + scores = dets[:, 4] + + areas = (x2 - x1) * (y2 - y1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(int(i)) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + + return np.array(keep).astype(np.int) + + +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def nms(boxes, scores, overlap=0.5, top_k=200): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. + """ + + keep = scores.new(scores.size(0)).zero_().long() + if boxes.numel() == 0: + return keep, 0 + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = torch.mul(x2 - x1, y2 - y1) + v, idx = scores.sort(0) # sort in ascending order + # I = I[v >= 0.01] + idx = idx[-top_k:] # indices of the top-k largest vals + xx1 = boxes.new() + yy1 = boxes.new() + xx2 = boxes.new() + yy2 = boxes.new() + w = boxes.new() + h = boxes.new() + + # keep = torch.Tensor() + count = 0 + while idx.numel() > 0: + i = idx[-1] # index of current largest val + # keep.append(i) + keep[count] = i + count += 1 + if idx.size(0) == 1: + break + idx = idx[:-1] # remove kept element from view + # load bboxes of next highest vals + torch.index_select(x1, 0, idx, out=xx1) + torch.index_select(y1, 0, idx, out=yy1) + torch.index_select(x2, 0, idx, out=xx2) + torch.index_select(y2, 0, idx, out=yy2) + # store element-wise max with next highest score + xx1 = torch.clamp(xx1, min=x1[i]) + yy1 = torch.clamp(yy1, min=y1[i]) + xx2 = torch.clamp(xx2, max=x2[i]) + yy2 = torch.clamp(yy2, max=y2[i]) + w.resize_as_(xx2) + h.resize_as_(yy2) + w = xx2 - xx1 + h = yy2 - yy1 + # check sizes of xx1 and xx2.. after each iteration + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + inter = w * h + # IoU = i / (area(a) + area(b) - i) + rem_areas = torch.index_select(area, 0, idx) # load remaining areas) + union = (rem_areas - inter) + area[i] + IoU = inter / union # store result in iou + # keep only elements with an IoU <= overlap + idx = idx[IoU.le(overlap)] + return keep, count + + +class Detect(object): + + def __init__(self, num_classes=2, + top_k=750, nms_thresh=0.3, conf_thresh=0.05, + variance=[0.1, 0.2], nms_top_k=5000): + + self.num_classes = num_classes + self.top_k = top_k + self.nms_thresh = nms_thresh + self.conf_thresh = conf_thresh + self.variance = variance + self.nms_top_k = nms_top_k + + def forward(self, loc_data, conf_data, prior_data): + + num = loc_data.size(0) + num_priors = prior_data.size(0) + + conf_preds = conf_data.view(num, num_priors, self.num_classes).transpose(2, 1) + batch_priors = prior_data.view(-1, num_priors, 4).expand(num, num_priors, 4) + batch_priors = batch_priors.contiguous().view(-1, 4) + + decoded_boxes = decode(loc_data.view(-1, 4), batch_priors, self.variance) + decoded_boxes = decoded_boxes.view(num, num_priors, 4) + + output = torch.zeros(num, self.num_classes, self.top_k, 5) + + for i in range(num): + boxes = decoded_boxes[i].clone() + conf_scores = conf_preds[i].clone() + + for cl in range(1, self.num_classes): + c_mask = conf_scores[cl].gt(self.conf_thresh) + scores = conf_scores[cl][c_mask] + + if scores.dim() == 0: + continue + l_mask = c_mask.unsqueeze(1).expand_as(boxes) + boxes_ = boxes[l_mask].view(-1, 4) + ids, count = nms(boxes_, scores, self.nms_thresh, self.nms_top_k) + count = count if count < self.top_k else self.top_k + + output[i, cl, :count] = torch.cat((scores[ids[:count]].unsqueeze(1), boxes_[ids[:count]]), 1) + + return output + + +class PriorBox(object): + + def __init__(self, input_size, feature_maps, + variance=[0.1, 0.2], + min_sizes=[16, 32, 64, 128, 256, 512], + steps=[4, 8, 16, 32, 64, 128], + clip=False): + + super(PriorBox, self).__init__() + + self.imh = input_size[0] + self.imw = input_size[1] + self.feature_maps = feature_maps + + self.variance = variance + self.min_sizes = min_sizes + self.steps = steps + self.clip = clip + + def forward(self): + mean = [] + for k, fmap in enumerate(self.feature_maps): + feath = fmap[0] + featw = fmap[1] + for i, j in product(range(feath), range(featw)): + f_kw = self.imw / self.steps[k] + f_kh = self.imh / self.steps[k] + + cx = (j + 0.5) / f_kw + cy = (i + 0.5) / f_kh + + s_kw = self.min_sizes[k] / self.imw + s_kh = self.min_sizes[k] / self.imh + + mean += [cx, cy, s_kw, s_kh] + + output = torch.FloatTensor(mean).view(-1, 4) + + if self.clip: + output.clamp_(max=1, min=0) + + return output diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/s3fd/nets.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/s3fd/nets.py new file mode 100644 index 00000000..85b5c82c --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/detectors/s3fd/nets.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from .box_utils import Detect, PriorBox + + +class L2Norm(nn.Module): + + def __init__(self, n_channels, scale): + super(L2Norm, self).__init__() + self.n_channels = n_channels + self.gamma = scale or None + self.eps = 1e-10 + self.weight = nn.Parameter(torch.Tensor(self.n_channels)) + self.reset_parameters() + + def reset_parameters(self): + init.constant_(self.weight, self.gamma) + + def forward(self, x): + norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps + x = torch.div(x, norm) + out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x + return out + + +class S3FDNet(nn.Module): + + def __init__(self, device='cuda'): + super(S3FDNet, self).__init__() + self.device = device + + self.vgg = nn.ModuleList([ + nn.Conv2d(3, 64, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, 64, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + nn.Conv2d(64, 128, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + nn.Conv2d(128, 256, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2, ceil_mode=True), + + nn.Conv2d(256, 512, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + nn.Conv2d(512, 512, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, 1, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), + + nn.Conv2d(512, 1024, 3, 1, padding=6, dilation=6), + nn.ReLU(inplace=True), + nn.Conv2d(1024, 1024, 1, 1), + nn.ReLU(inplace=True), + ]) + + self.L2Norm3_3 = L2Norm(256, 10) + self.L2Norm4_3 = L2Norm(512, 8) + self.L2Norm5_3 = L2Norm(512, 5) + + self.extras = nn.ModuleList([ + nn.Conv2d(1024, 256, 1, 1), + nn.Conv2d(256, 512, 3, 2, padding=1), + nn.Conv2d(512, 128, 1, 1), + nn.Conv2d(128, 256, 3, 2, padding=1), + ]) + + self.loc = nn.ModuleList([ + nn.Conv2d(256, 4, 3, 1, padding=1), + nn.Conv2d(512, 4, 3, 1, padding=1), + nn.Conv2d(512, 4, 3, 1, padding=1), + nn.Conv2d(1024, 4, 3, 1, padding=1), + nn.Conv2d(512, 4, 3, 1, padding=1), + nn.Conv2d(256, 4, 3, 1, padding=1), + ]) + + self.conf = nn.ModuleList([ + nn.Conv2d(256, 4, 3, 1, padding=1), + nn.Conv2d(512, 2, 3, 1, padding=1), + nn.Conv2d(512, 2, 3, 1, padding=1), + nn.Conv2d(1024, 2, 3, 1, padding=1), + nn.Conv2d(512, 2, 3, 1, padding=1), + nn.Conv2d(256, 2, 3, 1, padding=1), + ]) + + self.softmax = nn.Softmax(dim=-1) + self.detect = Detect() + + def forward(self, x): + size = x.size()[2:] + sources = list() + loc = list() + conf = list() + + for k in range(16): + x = self.vgg[k](x) + s = self.L2Norm3_3(x) + sources.append(s) + + for k in range(16, 23): + x = self.vgg[k](x) + s = self.L2Norm4_3(x) + sources.append(s) + + for k in range(23, 30): + x = self.vgg[k](x) + s = self.L2Norm5_3(x) + sources.append(s) + + for k in range(30, len(self.vgg)): + x = self.vgg[k](x) + sources.append(x) + + # apply extra layers and cache source layer outputs + for k, v in enumerate(self.extras): + x = F.relu(v(x), inplace=True) + if k % 2 == 1: + sources.append(x) + + # apply multibox head to source layers + loc_x = self.loc[0](sources[0]) + conf_x = self.conf[0](sources[0]) + + max_conf, _ = torch.max(conf_x[:, 0:3, :, :], dim=1, keepdim=True) + conf_x = torch.cat((max_conf, conf_x[:, 3:, :, :]), dim=1) + + loc.append(loc_x.permute(0, 2, 3, 1).contiguous()) + conf.append(conf_x.permute(0, 2, 3, 1).contiguous()) + + for i in range(1, len(sources)): + x = sources[i] + conf.append(self.conf[i](x).permute(0, 2, 3, 1).contiguous()) + loc.append(self.loc[i](x).permute(0, 2, 3, 1).contiguous()) + + features_maps = [] + for i in range(len(loc)): + feat = [] + feat += [loc[i].size(1), loc[i].size(2)] + features_maps += [feat] + + loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) + conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) + + with torch.no_grad(): + self.priorbox = PriorBox(size, features_maps) + self.priors = self.priorbox.forward() + + output = self.detect.forward( + loc.view(loc.size(0), -1, 4), + self.softmax(conf.view(conf.size(0), -1, 2)), + self.priors.type(type(x.data)).to(self.device) + ) + + return output diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/download_model.sh b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/download_model.sh new file mode 100644 index 00000000..3e3a9dc2 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/download_model.sh @@ -0,0 +1,9 @@ +# SyncNet model + +mkdir data +wget http://www.robots.ox.ac.uk/~vgg/software/lipsync/data/syncnet_v2.model -O data/syncnet_v2.model +wget http://www.robots.ox.ac.uk/~vgg/software/lipsync/data/example.avi -O data/example.avi + +# For the pre-processing pipeline +mkdir detectors/s3fd/weights +wget https://www.robots.ox.ac.uk/~vgg/software/lipsync/data/sfd_face.pth -O detectors/s3fd/weights/sfd_face.pth \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/how-to-run b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/how-to-run new file mode 100644 index 00000000..f578bb2b --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/how-to-run @@ -0,0 +1,2 @@ +python run_pipeline.py --videofile /path/to/your/video --reference wav2lip --data_dir tmp_dir +python calculate_scores_real_videos.py --videofile /path/to/you/video --reference wav2lip --data_dir tmp_dir >> all_scores.txt \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/img/ex1.jpg b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/img/ex1.jpg new file mode 100644 index 00000000..b20b57e1 Binary files /dev/null and b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/img/ex1.jpg differ diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/img/ex2.jpg b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/img/ex2.jpg new file mode 100644 index 00000000..851402cc Binary files /dev/null and b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/img/ex2.jpg differ diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/judge_lse.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/judge_lse.py new file mode 100644 index 00000000..cce0f7b9 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/judge_lse.py @@ -0,0 +1,52 @@ +import shutil +import subprocess +import sys +import os + +def run_commands(video_file_path): + current_dir = os.path.dirname(os.path.abspath(__file__)) + os.chdir(current_dir) + + # 定义要执行的命令 + command1 = [ + "python", "run_pipeline.py", + "--videofile", video_file_path, + "--reference", "wav2lip", + "--data_dir", "tmp_dir" + ] + + command2 = [ + "python", "calculate_scores_real_videos.py", + "--videofile", video_file_path, + "--reference", "wav2lip", + "--data_dir", "tmp_dir" + ] + + try: + # 预处理 + subprocess.run(command1, check=True, stdout=subprocess.DEVNULL) + + # 评估 + if os.path.exists("all_scores.txt"): + os.remove("all_scores.txt") + with open("all_scores.txt", "a") as score_file: + subprocess.run(command2, check=True, stdout=score_file) + + # 删除 tmp_dir 目录 + if os.path.exists("tmp_dir"): + shutil.rmtree("tmp_dir") + + except subprocess.CalledProcessError as e: + print(f"{e}") + sys.exit(1) + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python ") + sys.exit(1) + + # 获取命令行参数 + video_file_path = sys.argv[1] + + # 调用函数执行命令 + run_commands(video_file_path) diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/requirements.txt b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/requirements.txt new file mode 100644 index 00000000..89197409 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/requirements.txt @@ -0,0 +1,7 @@ +torch>=1.4.0 +torchvision>=0.5.0 +numpy>=1.18.1 +scipy>=1.2.1 +scenedetect==0.5.1 +opencv-contrib-python +python_speech_features diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/run_pipeline.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/run_pipeline.py new file mode 100644 index 00000000..f5fc22e0 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/run_pipeline.py @@ -0,0 +1,322 @@ +#!/usr/bin/python + +import sys, time, os, pdb, argparse, pickle, subprocess, glob, cv2 +import numpy as np +from shutil import rmtree + +import scenedetect +from scenedetect.video_manager import VideoManager +from scenedetect.scene_manager import SceneManager +from scenedetect.frame_timecode import FrameTimecode +from scenedetect.stats_manager import StatsManager +from scenedetect.detectors import ContentDetector + +from scipy.interpolate import interp1d +from scipy.io import wavfile +from scipy import signal + +from detectors import S3FD + +# ========== ========== ========== ========== +# # PARSE ARGS +# ========== ========== ========== ========== + +parser = argparse.ArgumentParser(description = "FaceTracker"); +parser.add_argument('--data_dir', type=str, default='data/work', help='Output direcotry'); +parser.add_argument('--videofile', type=str, default='', help='Input video file'); +parser.add_argument('--reference', type=str, default='', help='Video reference'); +parser.add_argument('--facedet_scale', type=float, default=0.25, help='Scale factor for face detection'); +parser.add_argument('--crop_scale', type=float, default=0.40, help='Scale bounding box'); +parser.add_argument('--min_track', type=int, default=100, help='Minimum facetrack duration'); +parser.add_argument('--frame_rate', type=int, default=25, help='Frame rate'); +parser.add_argument('--num_failed_det', type=int, default=25, help='Number of missed detections allowed before tracking is stopped'); +parser.add_argument('--min_face_size', type=int, default=100, help='Minimum face size in pixels'); +opt = parser.parse_args(); + +setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi')) +setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp')) +setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork')) +setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop')) +setattr(opt,'frames_dir',os.path.join(opt.data_dir,'pyframes')) + +# ========== ========== ========== ========== +# # IOU FUNCTION +# ========== ========== ========== ========== + +def bb_intersection_over_union(boxA, boxB): + + xA = max(boxA[0], boxB[0]) + yA = max(boxA[1], boxB[1]) + xB = min(boxA[2], boxB[2]) + yB = min(boxA[3], boxB[3]) + + interArea = max(0, xB - xA) * max(0, yB - yA) + + boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]) + boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1]) + + iou = interArea / float(boxAArea + boxBArea - interArea) + + return iou + +# ========== ========== ========== ========== +# # FACE TRACKING +# ========== ========== ========== ========== + +def track_shot(opt,scenefaces): + + iouThres = 0.5 # Minimum IOU between consecutive face detections + tracks = [] + + while True: + track = [] + for framefaces in scenefaces: + for face in framefaces: + if track == []: + track.append(face) + framefaces.remove(face) + elif face['frame'] - track[-1]['frame'] <= opt.num_failed_det: + iou = bb_intersection_over_union(face['bbox'], track[-1]['bbox']) + if iou > iouThres: + track.append(face) + framefaces.remove(face) + continue + else: + break + + if track == []: + break + elif len(track) > opt.min_track: + + framenum = np.array([ f['frame'] for f in track ]) + bboxes = np.array([np.array(f['bbox']) for f in track]) + + frame_i = np.arange(framenum[0],framenum[-1]+1) + + bboxes_i = [] + for ij in range(0,4): + interpfn = interp1d(framenum, bboxes[:,ij]) + bboxes_i.append(interpfn(frame_i)) + bboxes_i = np.stack(bboxes_i, axis=1) + + if max(np.mean(bboxes_i[:,2]-bboxes_i[:,0]), np.mean(bboxes_i[:,3]-bboxes_i[:,1])) > opt.min_face_size: + tracks.append({'frame':frame_i,'bbox':bboxes_i}) + + return tracks + +# ========== ========== ========== ========== +# # VIDEO CROP AND SAVE +# ========== ========== ========== ========== + +def crop_video(opt,track,cropfile): + + flist = glob.glob(os.path.join(opt.frames_dir,opt.reference,'*.jpg')) + flist.sort() + + fourcc = cv2.VideoWriter_fourcc(*'XVID') + vOut = cv2.VideoWriter(cropfile+'t.avi', fourcc, opt.frame_rate, (224,224)) + + dets = {'x':[], 'y':[], 's':[]} + + for det in track['bbox']: + + dets['s'].append(max((det[3]-det[1]),(det[2]-det[0]))/2) + dets['y'].append((det[1]+det[3])/2) # crop center x + dets['x'].append((det[0]+det[2])/2) # crop center y + + # Smooth detections + dets['s'] = signal.medfilt(dets['s'],kernel_size=13) + dets['x'] = signal.medfilt(dets['x'],kernel_size=13) + dets['y'] = signal.medfilt(dets['y'],kernel_size=13) + + for fidx, frame in enumerate(track['frame']): + + cs = opt.crop_scale + + bs = dets['s'][fidx] # Detection box size + bsi = int(bs*(1+2*cs)) # Pad videos by this amount + + image = cv2.imread(flist[frame]) + + frame = np.pad(image,((bsi,bsi),(bsi,bsi),(0,0)), 'constant', constant_values=(110,110)) + my = dets['y'][fidx]+bsi # BBox center Y + mx = dets['x'][fidx]+bsi # BBox center X + + face = frame[int(my-bs):int(my+bs*(1+2*cs)),int(mx-bs*(1+cs)):int(mx+bs*(1+cs))] + + vOut.write(cv2.resize(face,(224,224))) + + audiotmp = os.path.join(opt.tmp_dir,opt.reference,'audio.wav') + audiostart = (track['frame'][0])/opt.frame_rate + audioend = (track['frame'][-1]+1)/opt.frame_rate + + vOut.release() + + # ========== CROP AUDIO FILE ========== + + command = ("ffmpeg -y -i %s -ss %.3f -to %.3f %s" % (os.path.join(opt.avi_dir,opt.reference,'audio.wav'),audiostart,audioend,audiotmp)) + output = subprocess.call(command, shell=True, stdout=None) + + if output != 0: + pdb.set_trace() + + sample_rate, audio = wavfile.read(audiotmp) + + # ========== COMBINE AUDIO AND VIDEO FILES ========== + + command = ("ffmpeg -y -i %st.avi -i %s -c:v copy -c:a copy %s.avi" % (cropfile,audiotmp,cropfile)) + output = subprocess.call(command, shell=True, stdout=None) + + if output != 0: + pdb.set_trace() + + print('Written %s'%cropfile) + + os.remove(cropfile+'t.avi') + + print('Mean pos: x %.2f y %.2f s %.2f'%(np.mean(dets['x']),np.mean(dets['y']),np.mean(dets['s']))) + + return {'track':track, 'proc_track':dets} + +# ========== ========== ========== ========== +# # FACE DETECTION +# ========== ========== ========== ========== + +def inference_video(opt): + + DET = S3FD(device='cuda') + + flist = glob.glob(os.path.join(opt.frames_dir,opt.reference,'*.jpg')) + flist.sort() + + dets = [] + + for fidx, fname in enumerate(flist): + + start_time = time.time() + + image = cv2.imread(fname) + + image_np = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + bboxes = DET.detect_faces(image_np, conf_th=0.9, scales=[opt.facedet_scale]) + + dets.append([]); + for bbox in bboxes: + dets[-1].append({'frame':fidx, 'bbox':(bbox[:-1]).tolist(), 'conf':bbox[-1]}) + + elapsed_time = time.time() - start_time + + print('%s-%05d; %d dets; %.2f Hz' % (os.path.join(opt.avi_dir,opt.reference,'video.avi'),fidx,len(dets[-1]),(1/elapsed_time))) + + savepath = os.path.join(opt.work_dir,opt.reference,'faces.pckl') + + with open(savepath, 'wb') as fil: + pickle.dump(dets, fil) + + return dets + +# ========== ========== ========== ========== +# # SCENE DETECTION +# ========== ========== ========== ========== + +def scene_detect(opt): + + video_manager = VideoManager([os.path.join(opt.avi_dir,opt.reference,'video.avi')]) + stats_manager = StatsManager() + scene_manager = SceneManager(stats_manager) + # Add ContentDetector algorithm (constructor takes detector options like threshold). + scene_manager.add_detector(ContentDetector()) + base_timecode = video_manager.get_base_timecode() + + video_manager.set_downscale_factor() + + video_manager.start() + + scene_manager.detect_scenes(frame_source=video_manager) + + scene_list = scene_manager.get_scene_list(base_timecode) + + savepath = os.path.join(opt.work_dir,opt.reference,'scene.pckl') + + if scene_list == []: + scene_list = [(video_manager.get_base_timecode(),video_manager.get_current_timecode())] + + with open(savepath, 'wb') as fil: + pickle.dump(scene_list, fil) + + print('%s - scenes detected %d'%(os.path.join(opt.avi_dir,opt.reference,'video.avi'),len(scene_list))) + + return scene_list + + +# ========== ========== ========== ========== +# # EXECUTE DEMO +# ========== ========== ========== ========== + +# ========== DELETE EXISTING DIRECTORIES ========== + +if os.path.exists(os.path.join(opt.work_dir,opt.reference)): + rmtree(os.path.join(opt.work_dir,opt.reference)) + +if os.path.exists(os.path.join(opt.crop_dir,opt.reference)): + rmtree(os.path.join(opt.crop_dir,opt.reference)) + +if os.path.exists(os.path.join(opt.avi_dir,opt.reference)): + rmtree(os.path.join(opt.avi_dir,opt.reference)) + +if os.path.exists(os.path.join(opt.frames_dir,opt.reference)): + rmtree(os.path.join(opt.frames_dir,opt.reference)) + +if os.path.exists(os.path.join(opt.tmp_dir,opt.reference)): + rmtree(os.path.join(opt.tmp_dir,opt.reference)) + +# ========== MAKE NEW DIRECTORIES ========== + +os.makedirs(os.path.join(opt.work_dir,opt.reference)) +os.makedirs(os.path.join(opt.crop_dir,opt.reference)) +os.makedirs(os.path.join(opt.avi_dir,opt.reference)) +os.makedirs(os.path.join(opt.frames_dir,opt.reference)) +os.makedirs(os.path.join(opt.tmp_dir,opt.reference)) + +# ========== CONVERT VIDEO AND EXTRACT FRAMES ========== + +command = ("ffmpeg -y -i %s -qscale:v 2 -async 1 -r 25 %s" % (opt.videofile,os.path.join(opt.avi_dir,opt.reference,'video.avi'))) +output = subprocess.call(command, shell=True, stdout=None) + +command = ("ffmpeg -y -i %s -qscale:v 2 -threads 1 -f image2 %s" % (os.path.join(opt.avi_dir,opt.reference,'video.avi'),os.path.join(opt.frames_dir,opt.reference,'%06d.jpg'))) +output = subprocess.call(command, shell=True, stdout=None) + +command = ("ffmpeg -y -i %s -ac 1 -vn -acodec pcm_s16le -ar 16000 %s" % (os.path.join(opt.avi_dir,opt.reference,'video.avi'),os.path.join(opt.avi_dir,opt.reference,'audio.wav'))) +output = subprocess.call(command, shell=True, stdout=None) + +# ========== FACE DETECTION ========== + +faces = inference_video(opt) + +# ========== SCENE DETECTION ========== + +scene = scene_detect(opt) + +# ========== FACE TRACKING ========== + +alltracks = [] +vidtracks = [] + +for shot in scene: + + if shot[1].frame_num - shot[0].frame_num >= opt.min_track : + alltracks.extend(track_shot(opt,faces[shot[0].frame_num:shot[1].frame_num])) + +# ========== FACE TRACK CROP ========== + +for ii, track in enumerate(alltracks): + vidtracks.append(crop_video(opt,track,os.path.join(opt.crop_dir,opt.reference,'%05d'%ii))) + +# ========== SAVE RESULTS ========== + +savepath = os.path.join(opt.work_dir,opt.reference,'tracks.pckl') + +with open(savepath, 'wb') as fil: + pickle.dump(vidtracks, fil) + +rmtree(os.path.join(opt.tmp_dir,opt.reference)) diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/run_syncnet.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/run_syncnet.py new file mode 100644 index 00000000..45099fd6 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/run_syncnet.py @@ -0,0 +1,45 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import time, pdb, argparse, subprocess, pickle, os, gzip, glob + +from SyncNetInstance import * + +# ==================== PARSE ARGUMENT ==================== + +parser = argparse.ArgumentParser(description = "SyncNet"); +parser.add_argument('--initial_model', type=str, default="data/syncnet_v2.model", help=''); +parser.add_argument('--batch_size', type=int, default='20', help=''); +parser.add_argument('--vshift', type=int, default='15', help=''); +parser.add_argument('--data_dir', type=str, default='data/work', help=''); +parser.add_argument('--videofile', type=str, default='', help=''); +parser.add_argument('--reference', type=str, default='', help=''); +opt = parser.parse_args(); + +setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi')) +setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp')) +setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork')) +setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop')) + + +# ==================== LOAD MODEL AND FILE LIST ==================== + +s = SyncNetInstance(); + +s.loadParameters(opt.initial_model); +print("Model %s loaded."%opt.initial_model); + +flist = glob.glob(os.path.join(opt.crop_dir,opt.reference,'0*.avi')) +flist.sort() + +# ==================== GET OFFSETS ==================== + +dists = [] +for idx, fname in enumerate(flist): + offset, conf, dist = s.evaluate(opt,videofile=fname) + dists.append(dist) + +# ==================== PRINT RESULTS TO FILE ==================== + +with open(os.path.join(opt.work_dir,opt.reference,'activesd.pckl'), 'wb') as fil: + pickle.dump(dists, fil) diff --git a/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/run_visualise.py b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/run_visualise.py new file mode 100644 index 00000000..85d89253 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/LSE-C-D/run_visualise.py @@ -0,0 +1,88 @@ +#!/usr/bin/python +#-*- coding: utf-8 -*- + +import torch +import numpy +import time, pdb, argparse, subprocess, pickle, os, glob +import cv2 + +from scipy import signal + +# ==================== PARSE ARGUMENT ==================== + +parser = argparse.ArgumentParser(description = "SyncNet"); +parser.add_argument('--data_dir', type=str, default='data/work', help=''); +parser.add_argument('--videofile', type=str, default='', help=''); +parser.add_argument('--reference', type=str, default='', help=''); +parser.add_argument('--frame_rate', type=int, default=25, help='Frame rate'); +opt = parser.parse_args(); + +setattr(opt,'avi_dir',os.path.join(opt.data_dir,'pyavi')) +setattr(opt,'tmp_dir',os.path.join(opt.data_dir,'pytmp')) +setattr(opt,'work_dir',os.path.join(opt.data_dir,'pywork')) +setattr(opt,'crop_dir',os.path.join(opt.data_dir,'pycrop')) +setattr(opt,'frames_dir',os.path.join(opt.data_dir,'pyframes')) + +# ==================== LOAD FILES ==================== + +with open(os.path.join(opt.work_dir,opt.reference,'tracks.pckl'), 'rb') as fil: + tracks = pickle.load(fil, encoding='latin1') + +with open(os.path.join(opt.work_dir,opt.reference,'activesd.pckl'), 'rb') as fil: + dists = pickle.load(fil, encoding='latin1') + +flist = glob.glob(os.path.join(opt.frames_dir,opt.reference,'*.jpg')) +flist.sort() + +# ==================== SMOOTH FACES ==================== + +faces = [[] for i in range(len(flist))] + +for tidx, track in enumerate(tracks): + + mean_dists = numpy.mean(numpy.stack(dists[tidx],1),1) + minidx = numpy.argmin(mean_dists,0) + minval = mean_dists[minidx] + + fdist = numpy.stack([dist[minidx] for dist in dists[tidx]]) + fdist = numpy.pad(fdist, (3,3), 'constant', constant_values=10) + + fconf = numpy.median(mean_dists) - fdist + fconfm = signal.medfilt(fconf,kernel_size=9) + + for fidx, frame in enumerate(track['track']['frame'].tolist()) : + faces[frame].append({'track': tidx, 'conf':fconfm[fidx], 's':track['proc_track']['s'][fidx], 'x':track['proc_track']['x'][fidx], 'y':track['proc_track']['y'][fidx]}) + +# ==================== ADD DETECTIONS TO VIDEO ==================== + +first_image = cv2.imread(flist[0]) + +fw = first_image.shape[1] +fh = first_image.shape[0] + +fourcc = cv2.VideoWriter_fourcc(*'XVID') +vOut = cv2.VideoWriter(os.path.join(opt.avi_dir,opt.reference,'video_only.avi'), fourcc, opt.frame_rate, (fw,fh)) + +for fidx, fname in enumerate(flist): + + image = cv2.imread(fname) + + for face in faces[fidx]: + + clr = max(min(face['conf']*25,255),0) + + cv2.rectangle(image,(int(face['x']-face['s']),int(face['y']-face['s'])),(int(face['x']+face['s']),int(face['y']+face['s'])),(0,clr,255-clr),3) + cv2.putText(image,'Track %d, Conf %.3f'%(face['track'],face['conf']), (int(face['x']-face['s']),int(face['y']-face['s'])),cv2.FONT_HERSHEY_SIMPLEX,0.5,(255,255,255),2) + + vOut.write(image) + + print('Frame %d'%fidx) + +vOut.release() + +# ========== COMBINE AUDIO AND VIDEO FILES ========== + +command = ("ffmpeg -y -i %s -i %s -c:v copy -c:a copy %s" % (os.path.join(opt.avi_dir,opt.reference,'video_only.avi'),os.path.join(opt.avi_dir,opt.reference,'audio.wav'),os.path.join(opt.avi_dir,opt.reference,'video_out.avi'))) #-async 1 +output = subprocess.call(command, shell=True, stdout=None) + + diff --git a/LiveSpeechPortraits/source_code/judge_models/NIQE/niqe.py b/LiveSpeechPortraits/source_code/judge_models/NIQE/niqe.py new file mode 100644 index 00000000..900339e1 --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/NIQE/niqe.py @@ -0,0 +1,91 @@ +import cv2 +import numpy as np +import torch +from skimage import img_as_float +from skimage.feature import local_binary_pattern +from skimage.measure import shannon_entropy +from scipy import stats + +def compute_niqe(image): + """ + 计算图像的 NIQE 分数(无需使用 pyiqa 库)。 + 通过特征提取和统计比对来计算 NIQE 分数。 + + 参数: + - image (numpy.ndarray): 输入的图像。 + + 返回: + - niqe_score (float): 图像的 NIQE 分数。 + """ + + # 将图像转换为浮动类型,范围[0, 1] + image = img_as_float(image) + + # 计算图像的局部二值模式(LBP) + lbp = local_binary_pattern(image, P=8, R=1, method="uniform") + + # 计算图像的熵(信息量) + entropy = shannon_entropy(image) + + # 计算图像的均值、标准差 + mean = np.mean(image) + std = np.std(image) + + # 计算图像的偏度(Skewness)和峰度(Kurtosis) + skewness = stats.skew(image.flatten()) + kurtosis = stats.kurtosis(image.flatten()) + + # 通过特征合成一个简单的 NIQE 评分(简单示例,实际使用时需要根据预训练模型调整) + niqe_score = (mean + std + skewness + kurtosis + entropy) / 5 + + return niqe_score + +def calculate_niqe_for_video(video_path): + """ + 计算视频每一帧的 NIQE 分数,并返回视频的平均 NIQE 分数。 + + 参数: + - video_path (str): 视频文件的路径。 + + 返回: + - average_niqe_score (float): 视频的平均 NIQE 分数。 + """ + # 打开 AVI 视频文件 + cap = cv2.VideoCapture(video_path) + + if not cap.isOpened(): + print("Error: Could not open video.") + return None + + niqe_scores = [] + + while True: + # 读取一帧 + ret, frame = cap.read() + if not ret: + break + + # 将 BGR 帧转换为灰度图像 + frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + + # 计算 NIQE 分数 + score = compute_niqe(frame_gray) + niqe_scores.append(score) + + # 打印或存储分数 + print(f"Frame {len(niqe_scores)} NIQE score: {score}") + + # 释放视频捕获对象 + cap.release() + + # 计算平均的 NIQE 分数 + average_niqe_score = np.mean(niqe_scores) + print(f"Average NIQE score for the video: {average_niqe_score}") + + return average_niqe_score + +# 示例使用: +video_path = 'E:\\Code\\DesktopCode\\LiveSpeechPortraits\\results\\May\\May_short\\May_short.avi' # 替换为你的视频文件路径 +average_score = calculate_niqe_for_video(video_path) +if average_score is not None: + print(f"视频的平均 NIQE 分数为: {average_score}") diff --git a/LiveSpeechPortraits/source_code/judge_models/run_judge.py b/LiveSpeechPortraits/source_code/judge_models/run_judge.py new file mode 100644 index 00000000..957d4e4d --- /dev/null +++ b/LiveSpeechPortraits/source_code/judge_models/run_judge.py @@ -0,0 +1,100 @@ +import sys +import subprocess +import os +import argparse + +def compute_L1_PSNR_SSIM_LPIPS(gt_video, gen_video): + # 调用./L1_PSNR_SSIM_LPIPS目录下的eval.py文件 + try: + from L1_PSNR_SSIM_LPIPS.eval import compute_all_metrics + metrics = compute_all_metrics(gt_video, gen_video) + return metrics + except ImportError: + print("Error: Could not import the function 'compute_all_metrics' from eval.py.") + sys.exit(1) + +def compute_FID(gt_video, gen_video): + script_directory = os.path.dirname(os.path.abspath(__file__)) + script_directory = os.path.join(script_directory, 'fid_tmp') + try: + from FID.fid_eval import compute_fid_for_videos + value = compute_fid_for_videos(gt_video, gen_video, output_dir=script_directory) + return value + except ImportError: + print("Error: Could not import the function 'compute_all_metrics' from eval.py.") + sys.exit(1) + +def compute_LSE(gen_video): + # 调用./LSE-C-D目录下的judge_lse.py文件 + try: + subprocess.run(['python', './LSE-C-D/judge_lse.py', gen_video], check=True) + # 从结果文件中读取LSE-C和LSE-D + with open('./LSE-C-D/all_scores.txt', 'r') as file: + scores = file.readline().split() + lse_c, lse_d = float(scores[0]), float(scores[1]) + return lse_c, lse_d + except subprocess.CalledProcessError as e: + print(f"Error during LSE calculation: {e}") + sys.exit(1) + +def main(): + # 创建解析器 + parser = argparse.ArgumentParser(description="Evaluate generated video against ground truth.") + + # 添加参数 + parser.add_argument('--gt_video', type=str, required=True, help="Path to the ground truth video") + parser.add_argument('--gen_video', type=str, required=True, help="Path to the generated video") + + # 解析命令行参数 + args = parser.parse_args() + + # 获取参数值 + gt_video = args.gt_video + gen_video = args.gen_video + + # Step 1: 计算FID + print("=================================================") + print("[Computing] FID") + fid = compute_FID(gt_video, gen_video) + print("=================================================") + print("[Finished] FID") + + # Step 2: 计算L1、PSNR、SSIM、LPIPS + print("=================================================") + print("[Computing] L1 && PSNR && SSIM && LPIPS") + metrics = compute_L1_PSNR_SSIM_LPIPS(gt_video, gen_video) + print("=================================================") + print("[Finished] L1 && PSNR && SSIM && LPIPS") + l1, psnr, ssim, lpips = metrics[0], metrics[1], metrics[2], metrics[3] + + # Step 3: 计算LSE-C和LSE-D + print("=================================================") + print("[Computing] LSE-C && LSE-D") + print("=================================================") + print("This will take some time...") + lse_c, lse_d = compute_LSE(gen_video) + print("=================================================") + print("[Finished] LSE-C && LSE-D") + + # Step 3: 整合所有评估结果 + result = { + "L1": l1, + "PSNR": psnr, + "SSIM": ssim, + "LPIPS": lpips, + "LSE-C": lse_c, + "LSE-D": lse_d, + "FID": fid + } + + # 输出结果 + print("=================================================") + print("Evaluation Results:") + print("=================================================") + for metric, value in result.items(): + print(f"{metric}: {value}") + print("=================================================") + return result + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/models/__init__.py b/LiveSpeechPortraits/source_code/models/__init__.py new file mode 100644 index 00000000..9e434830 --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/__init__.py @@ -0,0 +1,142 @@ +"""This package contains modules related to objective functions, optimizations, and network architectures. + +To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. +You need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate loss, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + +In the function <__init__>, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. + +Now you can use the model class by specifying flag '--model dummy'. +See our template model class 'template_model.py' for more details. +""" +import os +import importlib +import numpy as np +import torch +import torch.nn as nn +from models.base_model import BaseModel + + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit() + + return model + + +def get_option_setter(model_name): + """Return the static method of the model class.""" + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + """Create a model given the option. + + This function warps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % type(instance).__name__) + return instance + + +def save_models(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, modelG, modelD, end_of_epoch=False): + if not end_of_epoch: + if total_steps % opt.save_latest_freq == 0: + visualizer.vis_print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) + modelG.module.save('latest') + modelD.module.save('latest') + np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') + else: + if epoch % opt.save_epoch_freq == 0: + visualizer.vis_print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) + modelG.module.save('latest') + modelD.module.save('latest') + modelG.module.save(epoch) + modelD.module.save(epoch) + np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') + + +def update_models(opt, epoch, modelG, modelD, dataset_warp): + ### linearly decay learning rate after certain iterations + if epoch > opt.niter: + modelG.module.update_learning_rate(epoch, 'G') + modelD.module.update_learning_rate(epoch, 'D') + + ### gradually grow training sequence length + if (epoch % opt.niter_step) == 0: + dataset_warp.dataset.update_training_batch(epoch//opt.niter_step) +# modelG.module.update_training_batch(epoch//opt.niter_step) + + ### finetune all scales + if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): + modelG.module.update_fixed_params() + + +class myModel(nn.Module): + def __init__(self, opt, model): + super(myModel, self).__init__() + self.opt = opt + self.module = model + self.model = nn.DataParallel(model, device_ids=opt.gpu_ids) + self.bs_per_gpu = int(np.ceil(float(opt.batch_size) / len(opt.gpu_ids))) # batch size for each GPU + self.pad_bs = self.bs_per_gpu * len(opt.gpu_ids) - opt.batch_size + + def forward(self, *inputs, **kwargs): + inputs = self.add_dummy_to_tensor(inputs, self.pad_bs) + outputs = self.model(*inputs, **kwargs, dummy_bs=self.pad_bs) + if self.pad_bs == self.bs_per_gpu: # gpu 0 does 0 batch but still returns 1 batch + return self.remove_dummy_from_tensor(outputs, 1) + return outputs + + def add_dummy_to_tensor(self, tensors, add_size=0): + if add_size == 0 or tensors is None: return tensors + if type(tensors) == list or type(tensors) == tuple: + return [self.add_dummy_to_tensor(tensor, add_size) for tensor in tensors] + + if isinstance(tensors, torch.Tensor): + dummy = torch.zeros_like(tensors)[:add_size] + tensors = torch.cat([dummy, tensors]) + return tensors + + def remove_dummy_from_tensor(self, tensors, remove_size=0): + if remove_size == 0 or tensors is None: return tensors + if type(tensors) == list or type(tensors) == tuple: + return [self.remove_dummy_from_tensor(tensor, remove_size) for tensor in tensors] + + if isinstance(tensors, torch.Tensor): + tensors = tensors[remove_size:] + return tensors + + diff --git a/LiveSpeechPortraits/source_code/models/audio2feature.py b/LiveSpeechPortraits/source_code/models/audio2feature.py new file mode 100644 index 00000000..91f81755 --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/audio2feature.py @@ -0,0 +1,84 @@ +import torch.nn as nn +from .networks import WaveNet + + + +class Audio2Feature(nn.Module): + def __init__(self, opt): + super(Audio2Feature, self).__init__() + self.opt = opt + opt.A2L_wavenet_input_channels = opt.APC_hidden_size + if self.opt.loss == 'GMM': + output_size = (2 * opt.A2L_GMM_ndim + 1) * opt.A2L_GMM_ncenter + elif self.opt.loss == 'L2': + num_pred = opt.predict_length + output_size = opt.A2L_GMM_ndim * num_pred + # define networks + if opt.feature_decoder == 'WaveNet': + self.WaveNet = WaveNet(opt.A2L_wavenet_residual_layers, + opt.A2L_wavenet_residual_blocks, + opt.A2L_wavenet_residual_channels, + opt.A2L_wavenet_dilation_channels, + opt.A2L_wavenet_skip_channels, + opt.A2L_wavenet_kernel_size, + opt.time_frame_length, + opt.A2L_wavenet_use_bias, + opt.A2L_wavenet_cond, + opt.A2L_wavenet_input_channels, + opt.A2L_GMM_ncenter, + opt.A2L_GMM_ndim, + output_size) + self.item_length = self.WaveNet.receptive_field + opt.time_frame_length - 1 + elif opt.feature_decoder == 'LSTM': + self.downsample = nn.Sequential( + nn.Linear(in_features=opt.APC_hidden_size * 2, out_features=opt.APC_hidden_size), + nn.BatchNorm1d(opt.APC_hidden_size), + nn.LeakyReLU(0.2), + nn.Linear(opt.APC_hidden_size, opt.APC_hidden_size), + ) + self.LSTM = nn.LSTM(input_size=opt.APC_hidden_size, + hidden_size=256, + num_layers=3, + dropout=0, + bidirectional=False, + batch_first=True) + self.fc = nn.Sequential( + nn.Linear(in_features=256, out_features=512), + nn.BatchNorm1d(512), + nn.LeakyReLU(0.2), + nn.Linear(512, 512), + nn.BatchNorm1d(512), + nn.LeakyReLU(0.2), + nn.Linear(512, output_size)) + + + def forward(self, audio_features): + ''' + Args: + audio_features: [b, T, ndim] + ''' + if self.opt.feature_decoder == 'WaveNet': + pred = self.WaveNet.forward(audio_features.permute(0,2,1)) + elif self.opt.feature_decoder == 'LSTM': + bs, item_len, ndim = audio_features.shape + # new in 0324 + audio_features = audio_features.reshape(bs, -1, ndim*2) + down_audio_feats = self.downsample(audio_features.reshape(-1, ndim*2)).reshape(bs, int(item_len/2), ndim) + output, (hn, cn) = self.LSTM(down_audio_feats) +# output, (hn, cn) = self.LSTM(audio_features) + pred = self.fc(output.reshape(-1, 256)).reshape(bs, int(item_len/2), -1) +# pred = self.fc(output.reshape(-1, 256)).reshape(bs, item_len, -1)[:, -self.opt.time_frame_length:, :] + + return pred + + + + + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/models/audio2feature_model.py b/LiveSpeechPortraits/source_code/models/audio2feature_model.py new file mode 100644 index 00000000..d029fdf3 --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/audio2feature_model.py @@ -0,0 +1,157 @@ +import numpy as np +import torch + +from .base_model import BaseModel +from . import networks +from . import audio2feature + + + + + +class Audio2FeatureModel(BaseModel): + def __init__(self, opt): + """Initialize the Audio2Feature class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + + # specify the models you want to save to the disk. The training/test scripts will call and + # define networks + self.model_names = ['Audio2Feature'] + self.Audio2Feature = networks.init_net(audio2feature.Audio2Feature(opt), init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids) + + # define only during training time + if self.isTrain: + # losses + self.featureL2loss = torch.nn.MSELoss().to(self.device) + # optimizer + self.optimizer = torch.optim.Adam([{'params':self.Audio2Feature.parameters(), + 'initial_lr': opt.lr}], lr=opt.lr, betas=(0.9, 0.99)) + + self.optimizers.append(self.optimizer) + + if opt.continue_train: + self.resume_training() + + + def resume_training(self): + opt = self.opt + ### if continue training, recover previous states + print('Resuming from epoch %s ' % (opt.load_epoch)) + # change epoch count & update schedule settings + opt.epoch_count = int(opt.load_epoch) + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + # print lerning rate + lr = self.optimizers[0].param_groups[0]['lr'] + print('update learning rate: {} -> {}'.format(opt.lr, lr)) + + + + def set_input(self, data, data_info=None): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + """ + + self.audio_feats, self.target_info = data +# b, item_length, mel_channels, width = self.audio_feats.shape + self.audio_feats = self.audio_feats.to(self.device) + self.target_info = self.target_info.to(self.device) + + # gaussian noise +# if self.opt.gaussian_noise: +# self.audio_feats = self.opt.gaussian_noise_scale * torch.randn(self.audio_feats.shape).cuda() +# self.target_info += self.opt.gaussian_noise_scale * torch.randn(self.target_info.shape).cuda() + + + + def forward(self): + ''' + Args: + history_landmarks: [b, T, ndim] + audio_features: [b, 1, nfeas, nwins] + Returns: + preds: [b, T, output_channels] + ''' + self.preds = self.Audio2Feature.forward(self.audio_feats) + + + + def calculate_loss(self): + """ calculate loss in detail, only forward pass included""" + if self.opt.loss == 'GMM': + b, T, _ = self.target_info.shape + self.loss_GMM = self.criterion_GMM(self.preds, self.target_info) + self.loss = self.loss_GMM + + elif self.opt.loss == 'L2': + frame_future = self.opt.frame_future + if not frame_future == 0: + self.loss = self.featureL2loss(self.preds[:, frame_future:], self.target_info[:, :-frame_future]) * 1000 + else: + self.loss = self.featureL2loss(self.preds, self.target_info) * 1000 + + + def backward(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + self.calculate_loss() + self.loss.backward() + + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.optimizer.zero_grad() # clear optimizer parameters grad + self.forward() # forward pass + self.backward() # calculate loss and gradients + self.optimizer.step() # update gradients + + + def validate(self): + """ validate process """ + with torch.no_grad(): + self.forward() + self.calculate_loss() + + + def generate_sequences(self, audio_feats, sample_rate = 16000, fps=60, fill_zero=True, opt=[]): + ''' generate landmark sequences given audio and a initialized landmark. + Note that the audio input should have the same sample rate as the training. + Args: + audio_sequences: [n,], in numpy + init_landmarks: [npts, 2], in numpy + sample_rate: audio sample rate, should be same as training process. + method(str): optional, how to generate the sequence, indeed it is the + loss function during training process. Options are 'L2' or 'GMM'. + Reutrns: + landmark_sequences: [T, npts, 2] predition landmark sequences + ''' + + frame_future = opt.frame_future + nframe = int(audio_feats.shape[0] / 2) + + if not frame_future == 0: + audio_feats_insert = np.repeat(audio_feats[-1], 2 * (frame_future)).reshape(-1, 2 * (frame_future)).T + audio_feats = np.concatenate([audio_feats, audio_feats_insert]) + + + # evaluate mode + self.Audio2Feature.eval() + + with torch.no_grad(): + input = torch.from_numpy(audio_feats).unsqueeze(0).float().to(self.device) + preds = self.Audio2Feature.forward(input) + + # drop first frame future results + if not frame_future == 0: + preds = preds[0, frame_future:].cpu().detach().numpy() + else: + preds = preds[0, :].cpu().detach().numpy() + + assert preds.shape[0] == nframe + + + return preds + + + \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/models/audio2headpose.py b/LiveSpeechPortraits/source_code/models/audio2headpose.py new file mode 100644 index 00000000..b6afe163 --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/audio2headpose.py @@ -0,0 +1,108 @@ +import torch.nn as nn + +from .networks import WaveNet + + + +class Audio2Headpose(nn.Module): + def __init__(self, opt): + super(Audio2Headpose, self).__init__() + self.opt = opt + if self.opt.loss == 'GMM': + output_size = (2 * opt.A2H_GMM_ndim + 1) * opt.A2H_GMM_ncenter + elif self.opt.loss == 'L2': + output_size = opt.A2H_GMM_ndim + # define networks + self.audio_downsample = nn.Sequential( + nn.Linear(in_features=opt.APC_hidden_size * 2, out_features=opt.APC_hidden_size), + nn.BatchNorm1d(opt.APC_hidden_size), + nn.LeakyReLU(0.2), + nn.Linear(opt.APC_hidden_size, opt.APC_hidden_size), + ) + + self.WaveNet = WaveNet(opt.A2H_wavenet_residual_layers, + opt.A2H_wavenet_residual_blocks, + opt.A2H_wavenet_residual_channels, + opt.A2H_wavenet_dilation_channels, + opt.A2H_wavenet_skip_channels, + opt.A2H_wavenet_kernel_size, + opt.time_frame_length, + opt.A2H_wavenet_use_bias, + True, + opt.A2H_wavenet_input_channels, + opt.A2H_GMM_ncenter, + opt.A2H_GMM_ndim, + output_size, + opt.A2H_wavenet_cond_channels) + self.item_length = self.WaveNet.receptive_field + opt.time_frame_length - 1 + + + def forward(self, history_info, audio_features): + ''' + Args: + history_info: [b, T, ndim] + audio_features: [b, 1, nfeas, nwins] + ''' + # APC features: [b, item_length, APC_hidden_size] ==> [b, APC_hidden_size, item_length] + bs, item_len, ndim = audio_features.shape + down_audio_feats = self.audio_downsample(audio_features.reshape(-1, ndim)).reshape(bs, item_len, -1) + pred = self.WaveNet.forward(history_info.permute(0,2,1), down_audio_feats.transpose(1,2)) + + + return pred + + + + +class Audio2Headpose_LSTM(nn.Module): + def __init__(self, opt): + super(Audio2Headpose_LSTM, self).__init__() + self.opt = opt + if self.opt.loss == 'GMM': + output_size = (2 * opt.A2H_GMM_ndim + 1) * opt.A2H_GMM_ncenter + elif self.opt.loss == 'L2': + output_size = opt.A2H_GMM_ndim + # define networks + self.audio_downsample = nn.Sequential( + nn.Linear(in_features=opt.APC_hidden_size * 2, out_features=opt.APC_hidden_size), + nn.BatchNorm1d(opt.APC_hidden_size), + nn.LeakyReLU(0.2), + nn.Linear(opt.APC_hidden_size, opt.APC_hidden_size), + ) + + self.LSTM = nn.LSTM(input_size=opt.APC_hidden_size, + hidden_size=256, + num_layers=3, + dropout=0, + bidirectional=False, + batch_first=True) + self.fc = nn.Sequential( + nn.Linear(in_features=256, out_features=512), + nn.BatchNorm1d(512), + nn.LeakyReLU(0.2), + nn.Linear(512, 512), + nn.BatchNorm1d(512), + nn.LeakyReLU(0.2), + nn.Linear(512, output_size)) + + + def forward(self, audio_features): + ''' + Args: + history_info: [b, T, ndim] + audio_features: [b, 1, nfeas, nwins] + ''' + # APC features: [b, item_length, APC_hidden_size] ==> [b, APC_hidden_size, item_length] + bs, item_len, ndim = audio_features.shape + down_audio_feats = self.audio_downsample(audio_features.reshape(-1, ndim)).reshape(bs, item_len, -1) + output, (hn, cn) = self.LSTM(down_audio_feats) + pred = self.fc(output.reshape(-1, 256)).reshape(bs, item_len, -1) + + + return pred + + + + + + \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/models/audio2headpose_model.py b/LiveSpeechPortraits/source_code/models/audio2headpose_model.py new file mode 100644 index 00000000..5b713891 --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/audio2headpose_model.py @@ -0,0 +1,208 @@ +import numpy as np +import torch +from tqdm import tqdm + +from .base_model import BaseModel +from . import networks +from . import audio2headpose +from .losses import GMMLogLoss, Sample_GMM +import torch.nn as nn + + + +class Audio2HeadposeModel(BaseModel): + def __init__(self, opt): + """Initialize the Audio2Headpose class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + + # specify the models you want to save to the disk. The training/test scripts will call and + # define networks + self.model_names = ['Audio2Headpose'] + if opt.feature_decoder == 'WaveNet': + self.Audio2Headpose = networks.init_net(audio2headpose.Audio2Headpose(opt), init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids) + elif opt.feature_decoder == 'LSTM': + self.Audio2Headpose = networks.init_net(audio2headpose.Audio2Headpose_LSTM(opt), init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids) + + # define only during training time + if self.isTrain: + # losses + self.criterion_GMM = GMMLogLoss(opt.A2H_GMM_ncenter, opt.A2H_GMM_ndim, opt.A2H_GMM_sigma_min).to(self.device) + self.criterion_L2 = nn.MSELoss().cuda() + # optimizer + self.optimizer = torch.optim.Adam([{'params':self.Audio2Headpose.parameters(), + 'initial_lr': opt.lr}], lr=opt.lr, betas=(0.9, 0.99)) + + self.optimizers.append(self.optimizer) + + if opt.continue_train: + self.resume_training() + + + def resume_training(self): + opt = self.opt + ### if continue training, recover previous states + print('Resuming from epoch %s ' % (opt.load_epoch)) + # change epoch count & update schedule settings + opt.epoch_count = int(opt.load_epoch) + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + # print lerning rate + lr = self.optimizers[0].param_groups[0]['lr'] + print('update learning rate: {} -> {}'.format(opt.lr, lr)) + + + + def set_input(self, data, data_info=None): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + """ + if self.opt.feature_decoder == 'WaveNet': + self.headpose_audio_feats, self.history_headpose, self.target_headpose = data + self.headpose_audio_feats = self.headpose_audio_feats.to(self.device) + self.history_headpose = self.history_headpose.to(self.device) + self.target_headpose = self.target_headpose.to(self.device) + elif self.opt.feature_decoder == 'LSTM': + self.headpose_audio_feats, self.target_headpose = data + self.headpose_audio_feats = self.headpose_audio_feats.to(self.device) + self.target_headpose = self.target_headpose.to(self.device) + + + + def forward(self): + ''' + Args: + history_landmarks: [b, T, ndim] + audio_features: [b, 1, nfeas, nwins] + Returns: + preds: [b, T, output_channels] + ''' + + if self.opt.audio_windows == 2: + bs, item_len, ndim = self.headpose_audio_feats.shape + self.headpose_audio_feats = self.headpose_audio_feats.reshape(bs, -1, ndim * 2) + else: + bs, item_len, _, ndim = self.headpose_audio_feats.shape + if self.opt.feature_decoder == 'WaveNet': + self.preds_headpose = self.Audio2Headpose.forward(self.history_headpose, self.headpose_audio_feats) + elif self.opt.feature_decoder == 'LSTM': + self.preds_headpose = self.Audio2Headpose.forward(self.headpose_audio_feats) + + + def calculate_loss(self): + """ calculate loss in detail, only forward pass included""" + if self.opt.loss == 'GMM': + self.loss_GMM = self.criterion_GMM(self.preds_headpose, self.target_headpose) + self.loss = self.loss_GMM + elif self.opt.loss == 'L2': + self.loss_L2 = self.criterion_L2(self.preds_headpose, self.target_headpose) + self.loss = self.loss_L2 + + if not self.opt.smooth_loss == 0: + mu_gen = Sample_GMM(self.preds_headpose, + self.Audio2Headpose.module.WaveNet.ncenter, + self.Audio2Headpose.module.WaveNet.ndim, + sigma_scale=0) + self.smooth_loss = (mu_gen[:,2:] + self.target_headpose[:,:-2] - 2 * self.target_headpose[:,1:-1]).mean(dim=2).abs().mean() + self.loss += self.smooth_loss * self.opt.smooth_loss + + + + def backward(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + self.calculate_loss() + self.loss.backward() + + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.optimizer.zero_grad() # clear optimizer parameters grad + self.forward() # forward pass + self.backward() # calculate loss and gradients + self.optimizer.step() # update gradients + + + def validate(self): + """ validate process """ + with torch.no_grad(): + self.forward() + self.calculate_loss() + + + def generate_sequences(self, audio_feats, pre_headpose, fill_zero=True, sigma_scale=0.0, opt=[]): + ''' generate landmark sequences given audio and a initialized landmark. + Note that the audio input should have the same sample rate as the training. + Args: + audio_sequences: [n,], in numpy + init_landmarks: [npts, 2], in numpy + sample_rate: audio sample rate, should be same as training process. + method(str): optional, how to generate the sequence, indeed it is the + loss function during training process. Options are 'L2' or 'GMM'. + Reutrns: + landmark_sequences: [T, npts, 2] predition landmark sequences + ''' + + frame_future = opt.frame_future + audio_feats = audio_feats.reshape(-1, 512 * 2) + nframe = audio_feats.shape[0] - frame_future + pred_headpose = np.zeros([nframe, opt.A2H_GMM_ndim]) + + if opt.feature_decoder == 'WaveNet': + # fill zero or not + if fill_zero == True: + # headpose + audio_feats_insert = np.repeat(audio_feats[0], opt.A2H_receptive_field - 1) + audio_feats_insert = audio_feats_insert.reshape(-1, opt.A2H_receptive_field - 1).T + audio_feats = np.concatenate([audio_feats_insert, audio_feats]) + # history headpose + history_headpose = np.repeat(pre_headpose, opt.A2H_receptive_field) + history_headpose = history_headpose.reshape(-1, opt.A2H_receptive_field).T + history_headpose = torch.from_numpy(history_headpose).unsqueeze(0).float().to(self.device) + infer_start = 0 + else: + return None + + # evaluate mode + self.Audio2Headpose.eval() + + with torch.no_grad(): + for i in tqdm(range(infer_start, nframe), desc='generating headpose'): + history_start = i - infer_start + input_audio_feats = audio_feats[history_start + frame_future: history_start + frame_future + opt.A2H_receptive_field] + input_audio_feats = torch.from_numpy(input_audio_feats).unsqueeze(0).float().to(self.device) + + if self.opt.feature_decoder == 'WaveNet': + preds = self.Audio2Headpose.forward(history_headpose, input_audio_feats) + elif self.opt.feature_decoder == 'LSTM': + preds = self.Audio2Headpose.forward(input_audio_feats) + + if opt.loss == 'GMM': + pred_data = Sample_GMM(preds, opt.A2H_GMM_ncenter, opt.A2H_GMM_ndim, sigma_scale=sigma_scale) + elif opt.loss == 'L2': + pred_data = preds + + # get predictions + pred_headpose[i] = pred_data[0,0].cpu().detach().numpy() + history_headpose = torch.cat((history_headpose[:,1:,:], pred_data.to(self.device)), dim=1) # add in time-axis + + return pred_headpose + + elif opt.feature_decoder == 'LSTM': + self.Audio2Headpose.eval() + with torch.no_grad(): + input = torch.from_numpy(audio_feats).unsqueeze(0).float().to(self.device) + preds = self.Audio2Headpose.forward(input) + if opt.loss == 'GMM': + pred_data = Sample_GMM(preds, opt.A2H_GMM_ncenter, opt.A2H_GMM_ndim, sigma_scale=sigma_scale) + elif opt.loss == 'L2': + pred_data = preds + # get predictions + pred_headpose = pred_data[0].cpu().detach().numpy() + + return pred_headpose + + + + + \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/models/base_model.py b/LiveSpeechPortraits/source_code/models/base_model.py new file mode 100644 index 00000000..b654005b --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/base_model.py @@ -0,0 +1,272 @@ +import os +import torch +import numpy as np +from collections import OrderedDict +from abc import ABC, abstractmethod +from . import networks + + +class BaseModel(ABC): + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization. + In this function, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + # get device name: CPU or GPU + # if self.gpu_ids == '-1': + # self.device = torch.device('cpu') + # self.gpu_ids = opt.gpu_ids == [] + # else: + # self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if len(self.gpu_ids) > 0 else torch.device('cpu') + + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + # torch speed up training + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. + """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + if not self.isTrain or opt.continue_train: + self.load_networks(opt.load_epoch) + self.print_networks(opt.verbose) + + + def train(self): + """Make models train mode during train time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.train(mode=True) + + + def eval(self): + """Make models eval mode during test time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.eval() + + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self): + """ Return image paths that are used to load current data""" + return self.image_paths + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + def get_current_visuals(self): + """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name) + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch, train_info=None): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + for name in self.model_names: + if isinstance(name, str): + save_filename = '%s_%s.pkl' % (epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, name) + torch.save(net.state_dict(), save_path) + if train_info is not None: + epoch, epoch_iter = train_info + iter_path = os.path.join(self.save_dir, 'iter.txt') + np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') + + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + + + for name in self.model_names: + if isinstance(name, str): + if epoch[-3:] == 'pkl': + load_path = epoch + else: + load_filename = '%s_%s.pkl' % (epoch, name) + load_path = os.path.join(self.save_dir, load_filename) + net = getattr(self, name) +# if isinstance(net, torch.nn.DataParallel): +# net = net.module + if os.path.exists(load_path): + state_dict = torch.load(load_path, map_location=str(self.device)) + if self.device == torch.device('cpu'): + for key in list(state_dict.keys()): + state_dict[key[7:]] = state_dict.pop(key) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + print('loading the model from %s' % load_path) + net.load_state_dict(state_dict, strict=False) + else: + print('No model weight file:', load_path, 'initialize model without pre-trained weights.') + if self.isTrain == False: + raise ValueError('We are now in inference process, no pre-trained model found! Check the model checkpoint!') + + +# if isinstance(net, torch.nn.DataParallel): +# net = net.module + + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + +# state_dict = torch.load(load_path, map_location=str(self.device)) +# if hasattr(state_dict, '_metadata'): +# del state_dict._metadata +# +# # patch InstanceNorm checkpoints prior to 0.4 +# for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop +# self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) +# net.load_state_dict(state_dict) + + + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad diff --git a/LiveSpeechPortraits/source_code/models/feature2face_D.py b/LiveSpeechPortraits/source_code/models/feature2face_D.py new file mode 100644 index 00000000..73b7ff65 --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/feature2face_D.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn + + +from .networks import MultiscaleDiscriminator +from torch.cuda.amp import autocast as autocast + + + +class Feature2Face_D(nn.Module): + def __init__(self, opt): + super(Feature2Face_D, self).__init__() + # initialize + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor + self.tD = opt.n_frames_D + self.output_nc = opt.output_nc + + # define networks + self.netD = MultiscaleDiscriminator(23 + 3, opt.ndf, opt.n_layers_D, opt.num_D, not opt.no_ganFeat) + + print('---------- Discriminator networks initialized -------------') + print('-----------------------------------------------------------') + + #@autocast() + def forward(self, input): + if self.opt.fp16: + with autocast(): + pred = self.netD(input) + else: + pred = self.netD(input) + + return pred + + + + + + + + + + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/models/feature2face_G.py b/LiveSpeechPortraits/source_code/models/feature2face_G.py new file mode 100644 index 00000000..37234dff --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/feature2face_G.py @@ -0,0 +1,46 @@ +import torch.nn as nn + +from .networks import Feature2FaceGenerator_Unet, Feature2FaceGenerator_normal, Feature2FaceGenerator_large + +from torch.cuda.amp import autocast as autocast + + +class Feature2Face_G(nn.Module): + def __init__(self, opt): + super(Feature2Face_G, self).__init__() + # initialize + self.opt = opt + self.isTrain = opt.isTrain + # define net G + + if opt.size == 'small': + self.netG = Feature2FaceGenerator_Unet(input_nc=23, output_nc=3, num_downs=opt.n_downsample_G, ngf=opt.ngf) + elif opt.size == 'normal': + self.netG = Feature2FaceGenerator_normal(input_nc=13, output_nc=3, num_downs=opt.n_downsample_G, ngf=opt.ngf) + elif opt.size == 'large': + self.netG = Feature2FaceGenerator_large(input_nc=13, output_nc=3, num_downs=opt.n_downsample_G, ngf=opt.ngf) + + print('---------- Generator networks initialized -------------') + print('-------------------------------------------------------') + + + def forward(self, input): + if self.opt.fp16: + with autocast(): + fake_pred = self.netG(input) + else: + fake_pred = self.netG(input) + + return fake_pred + + + + + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/models/feature2face_model.py b/LiveSpeechPortraits/source_code/models/feature2face_model.py new file mode 100644 index 00000000..b2fadb65 --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/feature2face_model.py @@ -0,0 +1,246 @@ +import os +import os.path +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import autocast as autocast + +from . import networks +from . import feature2face_G +from .base_model import BaseModel +from .losses import GANLoss, MaskedL1Loss, VGGLoss + + + +class Feature2FaceModel(BaseModel): + def __init__(self, opt): + """Initialize the Feature2Face class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + self.Tensor = torch.cuda.FloatTensor + # specify the models you want to save to the disk. The training/test scripts will call and + # define networks + self.model_names = ['Feature2Face_G'] + self.Feature2Face_G = networks.init_net(feature2face_G.Feature2Face_G(opt), init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids) + if self.isTrain: + if not opt.no_discriminator: + self.model_names += ['Feature2Face_D'] + from . import feature2face_D + self.Feature2Face_D = networks.init_net(feature2face_D.Feature2Face_D(opt), init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids) + + + # define only during training time + if self.isTrain: + # define losses names + self.loss_names_G = ['L1', 'VGG', 'Style', 'loss_G_GAN', 'loss_G_FM'] + # criterion + self.criterionMaskL1 = MaskedL1Loss().cuda() + self.criterionL1 = nn.L1Loss().cuda() + self.criterionVGG = VGGLoss.cuda() + self.criterionFlow = nn.L1Loss().cuda() + + # initialize optimizer G + if opt.TTUR: + beta1, beta2 = 0, 0.9 + lr = opt.lr / 2 + else: + beta1, beta2 = opt.beta1, 0.999 + lr = opt.lr + self.optimizer_G = torch.optim.Adam([{'params': self.Feature2Face_G.module.parameters(), + 'initial_lr': lr}], + lr=lr, + betas=(beta1, beta2)) + self.optimizers.append(self.optimizer_G) + + # fp16 training + if opt.fp16: + self.scaler = torch.cuda.amp.GradScaler() + + # discriminator setting + if not opt.no_discriminator: + self.criterionGAN = GANLoss(opt.gan_mode, tensor=self.Tensor) + self.loss_names_D = ['D_real', 'D_fake'] + # initialize optimizer D + if opt.TTUR: + beta1, beta2 = 0, 0.9 + lr = opt.lr * 2 + else: + beta1, beta2 = opt.beta1, 0.999 + lr = opt.lr + self.optimizer_D = torch.optim.Adam([{'params': self.Feature2Face_D.module.netD.parameters(), + 'initial_lr': lr}], + lr=lr, + betas=(beta1, beta2)) + self.optimizers.append(self.optimizer_D) + + + def init_paras(self, dataset): + opt = self.opt + iter_path = os.path.join(self.save_dir, 'iter.txt') + start_epoch, epoch_iter = 1, 0 + ### if continue training, recover previous states + if opt.continue_train: + if os.path.exists(iter_path): + start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) + print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) + # change epoch count & update schedule settings + opt.epoch_count = start_epoch + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + # print lerning rate + lr = self.optimizers[0].param_groups[0]['lr'] + print('update learning rate: {} -> {}'.format(opt.lr, lr)) + else: + print('not found training log, hence training from epoch 1') + # change training sequence length +# if start_epoch > opt.nepochs_step: +# dataset.dataset.update_training_batch((start_epoch-1)//opt.nepochs_step) + + + total_steps = (start_epoch-1) * len(dataset) + epoch_iter + total_steps = total_steps // opt.print_freq * opt.print_freq + + return start_epoch, opt.print_freq, total_steps, epoch_iter + + + + def set_input(self, data, data_info=None): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + """ + self.feature_map, self.cand_image, self.tgt_image, self.facial_mask = \ + data['feature_map'], data['cand_image'], data['tgt_image'], data['weight_mask'] + self.feature_map = self.feature_map.to(self.device) + self.cand_image = self.cand_image.to(self.device) + self.tgt_image = self.tgt_image.to(self.device) +# self.facial_mask = self.facial_mask.to(self.device) + + + def forward(self): + ''' forward pass for feature2Face + ''' + self.input_feature_maps = torch.cat([self.feature_map, self.cand_image], dim=1) + self.fake_pred = self.Feature2Face_G(self.input_feature_maps) + + + + + def backward_G(self): + """Calculate GAN and other loss for the generator""" + # GAN loss + real_AB = torch.cat((self.input_feature_maps, self.tgt_image), dim=1) + fake_AB = torch.cat((self.input_feature_maps, self.fake_pred), dim=1) + pred_real = self.Feature2Face_D(real_AB) + pred_fake = self.Feature2Face_D(fake_AB) + loss_G_GAN = self.criterionGAN(pred_fake, True) + # L1, vgg, style loss + loss_l1 = self.criterionL1(self.fake_pred, self.tgt_image) * self.opt.lambda_L1 +# loss_maskL1 = self.criterionMaskL1(self.fake_pred, self.tgt_image, self.facial_mask * self.opt.lambda_mask) + loss_vgg, loss_style = self.criterionVGG(self.fake_pred, self.tgt_image, style=True) + loss_vgg = torch.mean(loss_vgg) * self.opt.lambda_feat + loss_style = torch.mean(loss_style) * self.opt.lambda_feat + # feature matching loss + loss_FM = self.compute_FeatureMatching_loss(pred_fake, pred_real) + + # combine loss and calculate gradients + + if not self.opt.fp16: + self.loss_G = loss_G_GAN + loss_l1 + loss_vgg + loss_style + loss_FM #+ loss_maskL1 + self.loss_G.backward() + else: + with autocast(): + self.loss_G = loss_G_GAN + loss_l1 + loss_vgg + loss_style + loss_FM #+ loss_maskL1 + self.scaler.scale(self.loss_G).backward() + + self.loss_dict = {**self.loss_dict, **dict(zip(self.loss_names_G, [loss_l1, loss_vgg, loss_style, loss_G_GAN, loss_FM]))} + + + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # GAN loss + real_AB = torch.cat((self.input_feature_maps, self.tgt_image), dim=1) + fake_AB = torch.cat((self.input_feature_maps, self.fake_pred), dim=1) + pred_real = self.Feature2Face_D(real_AB) + pred_fake = self.Feature2Face_D(fake_AB.detach()) + with autocast(): + loss_D_real = self.criterionGAN(pred_real, True) * 2 + loss_D_fake = self.criterionGAN(pred_fake, False) + + self.loss_D = (loss_D_fake + loss_D_real) * 0.5 + + self.loss_dict = dict(zip(self.loss_names_D, [loss_D_real, loss_D_fake])) + + if not self.opt.fp16: + self.loss_D.backward() + else: + self.scaler.scale(self.loss_D).backward() + + + def compute_FeatureMatching_loss(self, pred_fake, pred_real): + # GAN feature matching loss + loss_FM = torch.zeros(1).cuda() + feat_weights = 4.0 / (self.opt.n_layers_D + 1) + D_weights = 1.0 / self.opt.num_D + for i in range(min(len(pred_fake), self.opt.num_D)): + for j in range(len(pred_fake[i])): + loss_FM += D_weights * feat_weights * \ + self.criterionL1(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat + + return loss_FM + + + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + # only train single image generation + ## forward + self.forward() + # update D + self.set_requires_grad(self.Feature2Face_D, True) # enable backprop for D + self.optimizer_D.zero_grad() # set D's gradients to zero + if not self.opt.fp16: + self.backward_D() # calculate gradients for D + self.optimizer_D.step() # update D's weights + else: + with autocast(): + self.backward_D() + self.scaler.step(self.optimizer_D) + + + # update G + self.set_requires_grad(self.Feature2Face_D, False) # D requires no gradients when optimizing G + self.optimizer_G.zero_grad() # set G's gradients to zero + if not self.opt.fp16: + self.backward_G() # calculate graidents for G + self.optimizer_G.step() # udpate G's weights + else: + with autocast(): + self.backward_G() + self.scaler.step(self.optimizer_G) + self.scaler.update() + + + def inference(self, feature_map, cand_image): + """ inference process """ + with torch.no_grad(): + if cand_image == None: + input_feature_maps = feature_map + else: + input_feature_maps = torch.cat([feature_map, cand_image], dim=1) + if not self.opt.fp16: + fake_pred = self.Feature2Face_G(input_feature_maps) + else: + with autocast(): + fake_pred = self.Feature2Face_G(input_feature_maps) + return fake_pred + + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/models/losses.py b/LiveSpeechPortraits/source_code/models/losses.py new file mode 100644 index 00000000..0834a72d --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/losses.py @@ -0,0 +1,278 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import math +import torch.nn.functional as F + + +class GMMLogLoss(nn.Module): + ''' compute the GMM loss between model output and the groundtruth data. + Args: + ncenter: numbers of gaussian distribution + ndim: dimension of each gaussian distribution + sigma_bias: + sigma_min: current we do not use it. + ''' + def __init__(self, ncenter, ndim, sigma_min=0.03): + super(GMMLogLoss,self).__init__() + self.ncenter = ncenter + self.ndim = ndim + self.sigma_min = sigma_min + + + def forward(self, output, target): + ''' + Args: + output: [b, T, ncenter + ncenter * ndim * 2]: + [:, :, : ncenter] shows each gaussian probability + [:, :, ncenter : ncenter + ndim * ncenter] shows the average values of each dimension of each gaussian + [: ,:, ncenter + ndim * ncenter : ncenter + ndim * 2 * ncenter] show the negative log sigma of each dimension of each gaussian + target: [b, T, ndim], the ground truth target landmark data is shown here + To maximize the log-likelihood equals to minimize the negative log-likelihood. + NOTE: It is unstable to directly compute the log results of sigma, e.g. ln(-0.1) as we need to clip the sigma results + into positive. Hence here we predict the negative log sigma results to avoid numerical instablility, which mean: + `` sigma = 1/exp(predict), predict = -ln(sigma) `` + Also, it will be just the 'B' term below! + Currently we only implement single gaussian distribution, hence the first values of pred are meaningless. + For single gaussian distribution: + L(mu, sigma) = -n/2 * ln(2pi * sigma^2) - 1 / (2 x sigma^2) * sum^n (x_i - mu)^2 (n for prediction times, n=1 for one frame, x_i for gt) + = -1/2 * ln(2pi) - 1/2 * ln(sigma^2) - 1/(2 x sigma^2) * (x - mu)^2 + == min -L(mu, sgima) = 0.5 x ln(2pi) + 0.5 x ln(sigma^2) + 1/(2 x sigma^2) * (x - mu)^2 + = 0.5 x ln_2PI + ln(sigma) + 0.5 x (MU_DIFF/sigma)^2 + = A - B + C + In batch and Time sample, b and T are summed and averaged. + ''' + b, T, _ = target.shape + # read prediction paras + mus = output[:, :, self.ncenter : (self.ncenter + self.ncenter * self.ndim)].view(b, T, self.ncenter, self.ndim) # [b, T, ncenter, ndim] + + # apply min sigma + neg_log_sigmas_out = output[:, :, (self.ncenter + self.ncenter * self.ndim):].view(b, T, self.ncenter, self.ndim) # [b, T, ncenter, ndim] + inv_sigmas_min = torch.ones(neg_log_sigmas_out.size()).cuda() * (1. / self.sigma_min) + inv_sigmas_min_log = torch.log(inv_sigmas_min) + neg_log_sigmas = torch.min(neg_log_sigmas_out, inv_sigmas_min_log) + + inv_sigmas = torch.exp(neg_log_sigmas) + # replicate the target of ncenter to minus mu + target_rep = target.unsqueeze(2).expand(b, T, self.ncenter, self.ndim) # [b, T, ncenter, ndim] + MU_DIFF = target_rep - mus # [b, T, ncenter, ndim] + # sigma process + A = 0.5 * math.log(2 * math.pi) # 0.9189385332046727 + B = neg_log_sigmas # [b, T, ncenter, ndim] + C = 0.5 * (MU_DIFF * inv_sigmas)**2 # [b, T, ncenter, ndim] + negative_loglikelihood = A - B + C # [b, T, ncenter, ndim] + + return negative_loglikelihood.mean() + + +def Sample_GMM(gmm_params, ncenter, ndim, weight_smooth = 0.0, sigma_scale = 0.0): + ''' Sample values from a given a GMM distribution. + Args: + gmm_params: [b, target_length, (2 * ndim + 1) * ncenter], including the + distribution weights, average and sigma + ncenter: numbers of gaussian distribution + ndim: dimension of each gaussian distribution + weight_smooth: float, smooth the gaussian distribution weights + sigma_scale: float, adjust the gaussian scale, larger for sharper prediction, + 0 for zero sigma which always return average values + Returns: + current_sample: [] + ''' + # reshape as [b*T, (2 * ndim + 1) * ncenter] + b, T, _ = gmm_params.shape + gmm_params_cpu = gmm_params.cpu().view(-1, (2 * ndim + 1) * ncenter) + # compute each distrubution probability + prob = nn.functional.softmax(gmm_params_cpu[:, : ncenter] * (1 + weight_smooth), dim=1) + # select the gaussian distribution according to their weights + selected_idx = torch.multinomial(prob, num_samples=1, replacement=True) + + mu = gmm_params_cpu[:, ncenter : ncenter + ncenter * ndim] + # please note that we use -logsigma as output, hence here we need to take the negative + sigma = torch.exp(-gmm_params_cpu[:, ncenter + ncenter * ndim:]) * sigma_scale +# print('sigma average:', sigma.mean()) + + selected_sigma = torch.empty(b*T, ndim).float() + selected_mu = torch.empty(b*T, ndim).float() + current_sample = torch.randn(b*T, ndim).float() +# current_sample = test_sample + + for i in range(b*T): + idx = selected_idx[i, 0] + selected_sigma[i, :] = sigma[i, idx * ndim:(idx + 1) * ndim] + selected_mu[i, :] = mu[i, idx * ndim:(idx + 1) * ndim] + + # sample with sel sigma and sel mean + current_sample = current_sample * selected_sigma + selected_mu + # cur_sample = sel_mu +# return current_sample.unsqueeze(1).cuda() + + if torch.cuda.is_available(): + return current_sample.reshape(b, T, -1).cuda() + else: + return current_sample.reshape(b, T, -1) + + + +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor): + super(GANLoss, self).__init__() + self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_var = None + self.fake_label_var = None + self.Tensor = tensor + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + target_tensor = None + gpu_id = input.get_device() + if target_is_real: + create_label = ((self.real_label_var is None) or + (self.real_label_var.numel() != input.numel())) + if create_label: + real_tensor = self.Tensor(input.size()).cuda(gpu_id).fill_(self.real_label) + self.real_label_var = Variable(real_tensor, requires_grad=False) + target_tensor = self.real_label_var + else: + create_label = ((self.fake_label_var is None) or + (self.fake_label_var.numel() != input.numel())) + if create_label: + fake_tensor = self.Tensor(input.size()).cuda(gpu_id).fill_(self.fake_label) + self.fake_label_var = Variable(fake_tensor, requires_grad=False) + target_tensor = self.fake_label_var + return target_tensor + + def __call__(self, input, target_is_real): + if isinstance(input[0], list): + loss = 0 + for input_i in input: + pred = input_i[-1] + target_tensor = self.get_target_tensor(pred, target_is_real) + loss += self.loss(pred, target_tensor) + return loss + else: + target_tensor = self.get_target_tensor(input[-1], target_is_real) + return self.loss(input[-1], target_tensor) + + + + +class VGGLoss(nn.Module): + def __init__(self, model=None): + super(VGGLoss, self).__init__() + if model is None: + self.vgg = Vgg19() + else: + self.vgg = model + + self.vgg.cuda() + # self.vgg.eval() + self.criterion = nn.L1Loss() + self.style_criterion = StyleLoss() + self.weights = [1.0, 1.0, 1.0, 1.0, 1.0] + self.style_weights = [1.0, 1.0, 1.0, 1.0, 1.0] + # self.weights = [5.0, 1.0, 0.5, 0.4, 0.8] + # self.style_weights = [10e4, 1000, 50, 15, 50] + + def forward(self, x, y, style=False): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + if style: + # return both perceptual loss and style loss. + style_loss = 0 + for i in range(len(x_vgg)): + this_loss = (self.weights[i] * + self.criterion(x_vgg[i], y_vgg[i].detach())) + this_style_loss = (self.style_weights[i] * + self.style_criterion(x_vgg[i], y_vgg[i].detach())) + loss += this_loss + style_loss += this_style_loss + return loss, style_loss + + for i in range(len(x_vgg)): + this_loss = (self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())) + loss += this_loss + return loss + + +def gram_matrix(input): + a, b, c, d = input.size() # a=batch size(=1) + # b=number of feature maps + # (c,d)=dimensions of a f. map (N=c*d) + features = input.view(a * b, c * d) # resise F_XL into \hat F_XL + G = torch.mm(features, features.t()) # compute the gram product + # we 'normalize' the values of the gram matrix + # by dividing by the number of element in each feature maps. + return G.div(a * b * c * d) + + +class StyleLoss(nn.Module): + def __init__(self): + super(StyleLoss, self).__init__() + + def forward(self, x, y): + Gx = gram_matrix(x) + Gy = gram_matrix(y) + return F.mse_loss(Gx, Gy) * 30000000 + + + +class MaskedL1Loss(nn.Module): + def __init__(self): + super(MaskedL1Loss, self).__init__() + self.criterion = nn.L1Loss() + + def forward(self, input, target, mask): + mask = mask.expand(-1, input.size()[1], -1, -1) + loss = self.criterion(input * mask, target * mask) + return loss + + + +from torchvision import models +class Vgg19(nn.Module): + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + + + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/models/networks.py b/LiveSpeechPortraits/source_code/models/networks.py new file mode 100644 index 00000000..2a684481 --- /dev/null +++ b/LiveSpeechPortraits/source_code/models/networks.py @@ -0,0 +1,873 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import lr_scheduler +from torch.nn import init +import functools + + +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence + + + +############################################################################### +# The detailed network architecture implementation for each model +############################################################################### + +class APC_encoder(nn.Module): + def __init__(self, + mel_dim, + hidden_size, + num_layers, + residual): + super(APC_encoder, self).__init__() + + input_size = mel_dim + + in_sizes = ([input_size] + [hidden_size] * (num_layers - 1)) + out_sizes = [hidden_size] * num_layers + self.rnns = nn.ModuleList( + [nn.GRU(input_size=in_size, hidden_size=out_size, batch_first=True) for (in_size, out_size) in zip(in_sizes, out_sizes)]) + + self.rnn_residual = residual + + def forward(self, inputs, lengths): + ''' + input: + inputs: (batch_size, seq_len, mel_dim) + lengths: (batch_size,) + + return: + predicted_mel: (batch_size, seq_len, mel_dim) + internal_reps: (num_layers + x, batch_size, seq_len, rnn_hidden_size), + where x is 1 if there's a prenet, otherwise 0 + ''' + with torch.no_grad(): + seq_len = inputs.size(1) + packed_rnn_inputs = pack_padded_sequence(inputs, lengths, True) + + for i, layer in enumerate(self.rnns): + packed_rnn_outputs, _ = layer(packed_rnn_inputs) + + rnn_outputs, _ = pad_packed_sequence( + packed_rnn_outputs, True, total_length=seq_len) + # outputs: (batch_size, seq_len, rnn_hidden_size) + + if i + 1 < len(self.rnns): + rnn_inputs, _ = pad_packed_sequence( + packed_rnn_inputs, True, total_length=seq_len) + # rnn_inputs: (batch_size, seq_len, rnn_hidden_size) + if self.rnn_residual and rnn_inputs.size(-1) == rnn_outputs.size(-1): + # Residual connections + rnn_outputs = rnn_outputs + rnn_inputs + packed_rnn_inputs = pack_padded_sequence(rnn_outputs, lengths, True) + + + return rnn_outputs + + + + +class WaveNet(nn.Module): + ''' This is a complete implementation of WaveNet architecture, mainly composed + of several residual blocks and some other operations. + Args: + batch_size: number of batch size + residual_layers: number of layers in each residual blocks + residual_blocks: number of residual blocks + dilation_channels: number of channels for the dilated convolution + residual_channels: number of channels for the residual connections + skip_channels: number of channels for the skip connections + end_channels: number of channels for the end convolution + classes: Number of possible values each sample can have as output + kernel_size: size of dilation convolution kernel + output_length(int): Number of samples that are generated for each input + use_bias: whether bias is used in each layer. + cond(bool): whether condition information are applied. if cond == True: + cond_channels: channel number of condition information + `` loss(str): GMM loss is adopted. `` + ''' + def __init__(self, + residual_layers = 10, + residual_blocks = 3, + dilation_channels = 32, + residual_channels = 32, + skip_channels = 256, + kernel_size = 2, + output_length = 16, + use_bias = False, + cond = True, + input_channels = 128, + ncenter = 1, + ndim = 73*2, + output_channels = 73*3, + cond_channels = 256, + activation = 'leakyrelu'): + super(WaveNet, self).__init__() + + self.layers = residual_layers + self.blocks = residual_blocks + self.dilation_channels = dilation_channels + self.residual_channels = residual_channels + self.skip_channels = skip_channels + self.input_channels = input_channels + self.ncenter = ncenter + self.ndim = ndim +# self.output_channels = (2 * self.ndim + 1) * self.ncenter + self.output_channels = output_channels + self.kernel_size = kernel_size + self.output_length = output_length + self.bias = use_bias + self.cond = cond + self.cond_channels = cond_channels + + # build modules + self.dilations = [] + self.dilation_queues = [] + residual_blocks = [] + self.receptive_field = 1 + + # 1x1 convolution to create channels + self.start_conv1 = nn.Conv1d(in_channels=self.input_channels, + out_channels=self.residual_channels, + kernel_size=1, + bias=True) + self.start_conv2 = nn.Conv1d(in_channels=self.residual_channels, + out_channels=self.residual_channels, + kernel_size=1, + bias=True) + if activation == 'relu': + self.activation = nn.ReLU(inplace = True) + elif activation == 'leakyrelu': + self.activation = nn.LeakyReLU(0.2) + self.drop_out2D = nn.Dropout2d(p=0.5) + + + # build residual blocks + for b in range(self.blocks): + new_dilation = 1 + additional_scope = kernel_size - 1 + for i in range(self.layers): + # create current residual block + residual_blocks.append(residual_block(dilation = new_dilation, + dilation_channels = self.dilation_channels, + residual_channels = self.residual_channels, + skip_channels = self.skip_channels, + kernel_size = self.kernel_size, + use_bias = self.bias, + cond = self.cond, + cond_channels = self.cond_channels)) + new_dilation *= 2 + + self.receptive_field += additional_scope + additional_scope *= 2 + + self.residual_blocks = nn.ModuleList(residual_blocks) + # end convolutions + + self.end_conv_1 = nn.Conv1d(in_channels = self.skip_channels, + out_channels = self.output_channels, + kernel_size = 1, + bias = True) + self.end_conv_2 = nn.Conv1d(in_channels = self.output_channels, + out_channels = self.output_channels, + kernel_size = 1, + bias = True) + + + def parameter_count(self): + par = list(self.parameters()) + s = sum([np.prod(list(d.size())) for d in par]) + return s + + def forward(self, input, cond=None): + ''' + Args: + input: [b, ndim, T] + cond: [b, nfeature, T] + Returns: + res: [b, T, ndim] + ''' + # dropout + x = self.drop_out2D(input) + + # preprocess + x = self.activation(self.start_conv1(x)) + x = self.activation(self.start_conv2(x)) + skip = 0 +# for i in range(self.blocks * self.layers): + for i, dilation_block in enumerate(self.residual_blocks): + x, current_skip = self.residual_blocks[i](x, cond) + skip += current_skip + + # postprocess + res = self.end_conv_1(self.activation(skip)) + res = self.end_conv_2(self.activation(res)) + + # cut the output size + res = res[:, :, -self.output_length:] # [b, ndim, T] + res = res.transpose(1, 2) # [b, T, ndim] + + return res + + + +class residual_block(nn.Module): + ''' + This is the implementation of a residual block in wavenet model. Every + residual block takes previous block's output as input. The forward pass of + each residual block can be illusatrated as below: + + ######################### Current Residual Block ########################## + # |-----------------------*residual*--------------------| # + # | | # + # | |-- dilated conv -- tanh --| | # + # -> -|-- pad--| * ---- |-- 1x1 -- + --> *input* # + # |-- dilated conv -- sigm --| | # + # 1x1 # + # | # + # ---------------------------------------------> + -------------> *skip* # + ########################################################################### + As shown above, each residual block returns two value: 'input' and 'skip': + 'input' is indeed this block's output and also is the next block's input. + 'skip' is the skip data which will be added finally to compute the prediction. + The input args own the same meaning in the WaveNet class. + + ''' + def __init__(self, + dilation, + dilation_channels = 32, + residual_channels = 32, + skip_channels = 256, + kernel_size = 2, + use_bias = False, + cond = True, + cond_channels = 128): + super(residual_block, self).__init__() + + self.dilation = dilation + self.dilation_channels = dilation_channels + self.residual_channels = residual_channels + self.skip_channels = skip_channels + self.kernel_size = kernel_size + self.bias = use_bias + self.cond = cond + self.cond_channels = cond_channels + # zero padding to the left of the sequence. + self.padding = (int((self.kernel_size - 1) * self.dilation), 0) + + # dilated convolutions + self.filter_conv= nn.Conv1d(in_channels = self.residual_channels, + out_channels = self.dilation_channels, + kernel_size = self.kernel_size, + dilation = self.dilation, + bias = self.bias) + + self.gate_conv = nn.Conv1d(in_channels = self.residual_channels, + out_channels = self.dilation_channels, + kernel_size = self.kernel_size, + dilation = self.dilation, + bias = self.bias) + + # 1x1 convolution for residual connections + self.residual_conv = nn.Conv1d(in_channels = self.dilation_channels, + out_channels = self.residual_channels, + kernel_size = 1, + bias = self.bias) + + # 1x1 convolution for skip connections + self.skip_conv = nn.Conv1d(in_channels = self.dilation_channels, + out_channels = self.skip_channels, + kernel_size = 1, + bias = self.bias) + + # condition conv, no dilation + if self.cond == True: + self.cond_filter_conv = nn.Conv1d(in_channels = self.cond_channels, + out_channels = self.dilation_channels, + kernel_size = 1, + bias = True) + self.cond_gate_conv = nn.Conv1d(in_channels = self.cond_channels, + out_channels = self.dilation_channels, + kernel_size = 1, + bias = True) + + + def forward(self, input, cond=None): + if self.cond is True and cond is None: + raise RuntimeError("set using condition to true, but no cond tensor inputed") + + x_pad = F.pad(input, self.padding) + # filter + filter = self.filter_conv(x_pad) + # gate + gate = self.gate_conv(x_pad) + + if self.cond == True and cond is not None: + filter_cond = self.cond_filter_conv(cond) + gate_cond = self.cond_gate_conv(cond) + # add cond results + filter = filter + filter_cond + gate = gate + gate_cond + + # element-wise multiple + filter = torch.tanh(filter) + gate = torch.sigmoid(gate) + x = filter * gate + + # residual and skip + residual = self.residual_conv(x) + input + skip = self.skip_conv(x) + + + return residual, skip + + + + +## 2D convolution layers +def conv2d(batch_norm, in_planes, out_planes, kernel_size=3, stride=1): + if batch_norm: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.2, inplace=True) + ) + else: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), + nn.LeakyReLU(0.2, inplace=True) + ) + + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], useDDP=False): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. + """ + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + if useDDP: + net = net().to(gpu_ids) + net = DDP(net, device_ids=gpu_ids) # DDP + print(f'use DDP to apply models on {gpu_ids}') + else: + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + init_weights(net, init_type, init_gain=init_gain) + return net + + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For 'linear', we keep the same learning rate for the first epochs + and linearly decay the rate to zero over the next epochs. + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. + """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch - opt.n_epochs) / float(opt.n_epochs_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule, last_epoch=opt.epoch_count-2) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=opt.gamma, last_epoch=opt.epoch_count-2) + for _ in range(opt.epoch_count-2): + scheduler.step() + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1 and hasattr(m, 'weight'): + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def print_network(net): + if isinstance(net, list): + net = net[0] + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print('Total number of parameters: %d' % num_params) + + + + +class Feature2FaceGenerator_normal(nn.Module): + def __init__(self, input_nc=4, output_nc=3, num_downs=8, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(Feature2FaceGenerator_normal, self).__init__() + # construct unet structure + unet_block = ResUnetSkipConnectionBlock_small(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, + innermost=True) + + for i in range(num_downs - 5): + unet_block = ResUnetSkipConnectionBlock_small(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, + norm_layer=norm_layer, use_dropout=use_dropout) + unet_block = ResUnetSkipConnectionBlock_small(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock_small(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock_small(ngf, ngf * 2, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock_small(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, + norm_layer=norm_layer) + + self.model = unet_block + + def forward(self, input): + output = self.model(input) + output = torch.tanh(output) # scale to [-1, 1] + + return output + + +# Defines the submodule with skip connection. +# X -------------------identity---------------------- X +# |-- downsampling -- |submodule| -- upsampling --| +class ResUnetSkipConnectionBlock_small(nn.Module): + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(ResUnetSkipConnectionBlock_small, self).__init__() + self.outermost = outermost + use_bias = norm_layer == nn.InstanceNorm2d + + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3, + stride=2, padding=1, bias=use_bias) + # add two resblock + res_downconv = [ResidualBlock(inner_nc, norm_layer)] + res_upconv = [ResidualBlock(outer_nc, norm_layer)] + + # res_downconv = [ResidualBlock(inner_nc)] + # res_upconv = [ResidualBlock(outer_nc)] + + downrelu = nn.ReLU(True) + uprelu = nn.ReLU(True) + if norm_layer != None: + downnorm = norm_layer(inner_nc) + upnorm = norm_layer(outer_nc) + + if outermost: + upsample = nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + down = [downconv, downrelu] + res_downconv + # up = [uprelu, upsample, upconv, upnorm] + up = [upsample, upconv] + model = down + [submodule] + up + elif innermost: + upsample = nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + down = [downconv, downrelu] + res_downconv + if norm_layer == None: + up = [upsample, upconv, uprelu] + res_upconv + else: + up = [upsample, upconv, upnorm, uprelu] + res_upconv + model = down + up + else: + upsample = nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + if norm_layer == None: + down = [downconv, downrelu] + res_downconv + up = [upsample, upconv, uprelu] + res_upconv + else: + down = [downconv, downnorm, downrelu] + res_downconv + up = [upsample, upconv, upnorm, uprelu] + res_upconv + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: + return torch.cat([x, self.model(x)], 1) + + + +class Feature2FaceGenerator_large(nn.Module): + def __init__(self, input_nc=4, output_nc=3, num_downs=8, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(Feature2FaceGenerator_large, self).__init__() + # construct unet structure + unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, + innermost=True) + + for i in range(num_downs - 5): + unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, + norm_layer=norm_layer, use_dropout=use_dropout) + unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, + norm_layer=norm_layer) + + self.model = unet_block + + def forward(self, input): + output = self.model(input) + output = torch.tanh(output) # scale to [-1, 1] + + return output + + +# Defines the submodule with skip connection. +# X -------------------identity---------------------- X +# |-- downsampling -- |submodule| -- upsampling --| +class ResUnetSkipConnectionBlock(nn.Module): + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(ResUnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + use_bias = norm_layer == nn.InstanceNorm2d + + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3, + stride=2, padding=1, bias=use_bias) + # add two resblock + res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)] + res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)] + + # res_downconv = [ResidualBlock(inner_nc)] + # res_upconv = [ResidualBlock(outer_nc)] + + downrelu = nn.ReLU(True) + uprelu = nn.ReLU(True) + if norm_layer != None: + downnorm = norm_layer(inner_nc) + upnorm = norm_layer(outer_nc) + + if outermost: + upsample = nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + down = [downconv, downrelu] + res_downconv + # up = [uprelu, upsample, upconv, upnorm] + up = [upsample, upconv] + model = down + [submodule] + up + elif innermost: + upsample = nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + down = [downconv, downrelu] + res_downconv + if norm_layer == None: + up = [upsample, upconv, uprelu] + res_upconv + else: + up = [upsample, upconv, upnorm, uprelu] + res_upconv + model = down + up + else: + upsample = nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + if norm_layer == None: + down = [downconv, downrelu] + res_downconv + up = [upsample, upconv, uprelu] + res_upconv + else: + down = [downconv, downnorm, downrelu] + res_downconv + up = [upsample, upconv, upnorm, uprelu] + res_upconv + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: + return torch.cat([x, self.model(x)], 1) + + +# UNet with residual blocks +class ResidualBlock(nn.Module): + def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d): + super(ResidualBlock, self).__init__() + self.relu = nn.ReLU(True) + if norm_layer == None: + # hard to converge with out batch or instance norm + self.block = nn.Sequential( + nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), + nn.ReLU(inplace=True), + nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), + ) + else: + self.block = nn.Sequential( + nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), + norm_layer(in_features), + nn.ReLU(inplace=True), + nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), + norm_layer(in_features) + ) + + def forward(self, x): + residual = x + out = self.block(x) + out += residual + out = self.relu(out) + return out + # return self.relu(x + self.block(x)) + + + +class Feature2FaceGenerator_Unet(nn.Module): + def __init__(self, input_nc=4, output_nc=3, num_downs=8, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(Feature2FaceGenerator_Unet, self).__init__() + + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + + def forward(self, input): + output = self.model(input) + + return output + + + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + + + +class MultiscaleDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, + num_D=3, getIntermFeat=False): + super(MultiscaleDiscriminator, self).__init__() + self.num_D = num_D + self.n_layers = n_layers + self.getIntermFeat = getIntermFeat + ndf_max = 64 + + for i in range(num_D): + netD = NLayerDiscriminator(input_nc, min(ndf_max, ndf*(2**(num_D-1-i))), n_layers, getIntermFeat) + if getIntermFeat: + for j in range(n_layers+2): + setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j))) + else: + setattr(self, 'layer'+str(i), netD.model) + + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + + def singleD_forward(self, model, input): + if self.getIntermFeat: + result = [input] + for i in range(len(model)): + result.append(model[i](result[-1])) + return result[1:] + else: + return [model(input)] + + def forward(self, input): + num_D = self.num_D + result = [] + input_downsampled = input + for i in range(num_D): + if self.getIntermFeat: + model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)] + else: + model = getattr(self, 'layer'+str(num_D-1-i)) + result.append(self.singleD_forward(model, input_downsampled)) + if i != (num_D-1): + input_downsampled = self.downsample(input_downsampled) + return result + + + +# Defines the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, getIntermFeat=False): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + nn.BatchNorm2d(nf), + nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + nn.BatchNorm2d(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[1:] + else: + return self.model(input) + + + + + + diff --git a/LiveSpeechPortraits/source_code/options/__init__.py b/LiveSpeechPortraits/source_code/options/__init__.py new file mode 100644 index 00000000..e7eedebe --- /dev/null +++ b/LiveSpeechPortraits/source_code/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/LiveSpeechPortraits/source_code/options/base_options_audio2feature.py b/LiveSpeechPortraits/source_code/options/base_options_audio2feature.py new file mode 100644 index 00000000..39835de8 --- /dev/null +++ b/LiveSpeechPortraits/source_code/options/base_options_audio2feature.py @@ -0,0 +1,185 @@ +import argparse +import os +from util import util +import torch +import numpy as np +import models + + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. + """ + + def __init__(self): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + ## task + parser.add_argument('--task', type=str, default='Audio2Feature', help='|Audio2Feature|Feature2Face|etc.') + + + ## basic parameters + parser.add_argument('--model', type=str, default='audio2feature', help='trained model') + parser.add_argument('--dataset_mode', type=str, default='audiovisual', help='chooses how datasets are loaded. [unaligned | aligned | single]') + parser.add_argument('--name', type=str, default='Audio2Feature', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/', help='models are saved here') + + + # dataset parameters + parser.add_argument('--dataset_names', type=str, default='default_name') + parser.add_argument('--dataroot', type=str, default='default_path') + parser.add_argument('--frame_jump_stride', type=int, default=4, help='jump index in audio dataset.') + parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data') + parser.add_argument('--batch_size', type=int, default=32, help='input batch size') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--audio_encoder', type=str, default='APC', help='|CNN|LSTM|APC|NPC|') + parser.add_argument('--feature_decoder', type=str, default='LSTM', help='|WaveNet|LSTM|') + parser.add_argument('--loss', type=str, default='L2', help='|GMM|L2|') + parser.add_argument('--A2L_GMM_ndim', type=int, default=25*3) + parser.add_argument('--sequence_length', type=int, default=240, help='length of training frames in each iteration') + + + # data setting parameters + parser.add_argument('--FPS', type=str, default=60, help='video fps') + parser.add_argument('--sample_rate', type=int, default=16000, help='audio sample rate') + parser.add_argument('--audioRF_history', type=int, default=60, help='audio history receptive field length') + parser.add_argument('--audioRF_future', type=int, default=0, help='audio future receptive field length') + parser.add_argument('--feature_dtype', type=str, default='pts3d', help='|FW|pts3d|') + parser.add_argument('--ispts_norm', type=int, default=1, help='use normalized 3d points.') + parser.add_argument('--use_delta_pts', type=int, default=1, help='whether use delta landmark representation') + parser.add_argument('--frame_future', type=int, default=18) + parser.add_argument('--predict_length', type=int, default=1) + parser.add_argument('--only_mouth', type=int, default=1) + + + # APC parameters + parser.add_argument('--APC_hidden_size', type=int, default=512) + parser.add_argument('--APC_rnn_layers', type=int, default=3) + parser.add_argument("--APC_residual", action="store_true") + parser.add_argument('--APC_frame_history', type=int, default=0) + + + # LSTM parameters + parser.add_argument('--LSTM_hidden_size', type=int, default=256) + parser.add_argument('--LSTM_output_size', type=int, default=80) + parser.add_argument('--LSTM_layers', type=int, default=3) + parser.add_argument('--LSTM_dropout', type=float, default=0) + parser.add_argument("--LSTM_residual", action="store_true") + parser.add_argument('--LSTM_sequence_length', type=int, default=60) + + + # additional parameters + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + + self.initialized = True + return parser + + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + print('opt:', opt) + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + # save and return the parser + self.parser = parser + return opt + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + if opt.isTrain: + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + # if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(opt.gpu_ids[0]) + + # set datasets + if self.isTrain: + opt.train_dataset_names = np.loadtxt(os.path.join(opt.dataroot, + opt.dataset_names, + opt.train_dataset_names), dtype=np.str).tolist() + if type(opt.train_dataset_names) == str: + opt.train_dataset_names = [opt.train_dataset_names] + opt.validate_dataset_names = np.loadtxt(os.path.join(opt.dataroot, + opt.dataset_names, + opt.validate_dataset_names), dtype=np.str).tolist() + if type(opt.validate_dataset_names) == str: + opt.validate_dataset_names = [opt.validate_dataset_names] + + self.opt = opt + return self.opt + + + + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/options/base_options_audio2headpose.py b/LiveSpeechPortraits/source_code/options/base_options_audio2headpose.py new file mode 100644 index 00000000..41c41718 --- /dev/null +++ b/LiveSpeechPortraits/source_code/options/base_options_audio2headpose.py @@ -0,0 +1,189 @@ +import argparse +import os +from util import util +import torch +import models +import numpy as np + + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. + """ + + def __init__(self): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + ## task + parser.add_argument('--task', type=str, default='Audio2Headpose', help='|Audio2Feature|Feature2Face|Full|') + + + ## basic parameters + parser.add_argument('--model', type=str, default='audio2headpose', help='trained model') + parser.add_argument('--dataset_mode', type=str, default='audiovisual', help='chooses how datasets are loaded. [unaligned | aligned | single]') + parser.add_argument('--name', type=str, default='Audio2Headpose', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/', help='models are saved here') + + + # data parameters + parser.add_argument('--FPS', type=str, default=60, help='video fps') + parser.add_argument('--sample_rate', type=int, default=16000, help='audio sample rate') + parser.add_argument('--audioRF_history', type=int, default=60, help='audio history receptive field length') + parser.add_argument('--audioRF_future', type=int, default=0, help='audio future receptive field length') + parser.add_argument('--feature_decoder', type=str, default='WaveNet', help='|WaveNet|LSTM|') + parser.add_argument('--loss', type=str, default='GMM', help='|GMM|L2|') + + + # dataset parameters + parser.add_argument('--dataset_names', type=str, default='name', help='chooses how datasets are loaded.') + parser.add_argument('--dataroot', type=str, default='path') + parser.add_argument('--frame_jump_stride', type=int, default=1, help='jump index in audio dataset.') + parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data') + parser.add_argument('--batch_size', type=int, default=32, help='input batch size') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--audio_encoder', type=str, default='APC', help='|CNN|LSTM|APC|') + parser.add_argument('--audiofeature_input_channels', type=int, default=80, help='input channels of audio features') + parser.add_argument('--frame_future', type=int, default=15) + parser.add_argument('--predict_length', type=int, default=5) + parser.add_argument('--audio_windows', type=int, default=2) + parser.add_argument('--time_frame_length', type=int, default=240, help='length of training frames in each iteration') + + + # APC parameters + parser.add_argument('--APC_hidden_size', type=int, default=512) + parser.add_argument('--APC_rnn_layers', type=int, default=3) + parser.add_argument("--APC_residual", action="store_true") + parser.add_argument('--APC_frame_history', type=int, default=60) + + + ## network parameters + # audio2headpose wavenet + parser.add_argument('--A2H_wavenet_residual_layers', type=int, default=7, help='residual layer numbers') + parser.add_argument('--A2H_wavenet_residual_blocks', type=int, default=2, help='residual block numbers') + parser.add_argument('--A2H_wavenet_dilation_channels', type=int, default=128, help='dilation convolution channels') + parser.add_argument('--A2H_wavenet_residual_channels', type=int, default=128, help='residual channels') + parser.add_argument('--A2H_wavenet_skip_channels', type=int, default=256, help='skip channels') + parser.add_argument('--A2H_wavenet_kernel_size', type=int, default=2, help='dilation convolution kernel size') + parser.add_argument('--A2H_wavenet_use_bias', type=bool, default=True, help='whether to use bias in dilation convolution') + parser.add_argument('--A2H_wavenet_cond', type=bool, default=True, help='whether use condition input') + parser.add_argument('--A2H_wavenet_cond_channels', type=int, default=512, help='whether use condition input') + parser.add_argument('--A2H_wavenet_input_channels', type=int, default=12, help='input channels') + parser.add_argument('--A2H_GMM_ncenter', type=int, default=1, help='gaussian distribution numbers, 1 for single gaussian distribution') + parser.add_argument('--A2H_GMM_ndim', type=int, default=12, help='dimension of each gaussian, usually number of pts') + parser.add_argument('--A2H_GMM_sigma_min', type=float, default=0.03, help='minimal gaussian sigma values') + + + # additional parameters + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + parser.add_argument('--sequence_length', type=int, default=240, help='length of training frames in each iteration') + + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + + # save and return the parser + self.parser = parser + return opt + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + if opt.isTrain: + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + # if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(opt.gpu_ids[0]) + + # set datasets + if self.isTrain: + opt.train_dataset_names = np.loadtxt(os.path.join(opt.dataroot, + opt.dataset_names, + opt.train_dataset_names), dtype=np.str).tolist() + if type(opt.train_dataset_names) == str: + opt.train_dataset_names = [opt.train_dataset_names] + opt.validate_dataset_names = np.loadtxt(os.path.join(opt.dataroot, + opt.dataset_names, + opt.validate_dataset_names), dtype=np.str).tolist() + if type(opt.validate_dataset_names) == str: + opt.validate_dataset_names = [opt.validate_dataset_names] + + self.opt = opt + return self.opt + + + + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/options/base_options_feature2face.py b/LiveSpeechPortraits/source_code/options/base_options_feature2face.py new file mode 100644 index 00000000..26655cec --- /dev/null +++ b/LiveSpeechPortraits/source_code/options/base_options_feature2face.py @@ -0,0 +1,129 @@ +import argparse +import os +from util import util +import torch +import numpy as np + +class BaseOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser() + self.initialized = False + + def initialize(self): + ## task + self.parser.add_argument('--task', type=str, default='Feature2Face', help='|Audio2Feature|Feature2Face|Full|') + self.parser.add_argument('--model', type=str, default='feature2face', help='chooses which model to use. vid2vid, test') + self.parser.add_argument('--name', type=str, default='TestRender', help='name of the experiment. It decides where to store samples and models') + self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/', help='models are saved here') + self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + + + # display + self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size') + self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display') + self.parser.add_argument('--tf_log', default=True, action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed') + + + # input/output size + self.parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + self.parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size') + self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') + self.parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels') + self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') + + + # setting inputs + self.parser.add_argument('--dataset_mode', type=str, default='face', help='chooses how datasets are loaded.') + self.parser.add_argument('--dataroot', type=str, default='./data/') + self.parser.add_argument('--isH5', type=int, default=1, help='whether to use h5py to save dataset') + self.parser.add_argument('--suffix', type=str, default='.jpg', help='image suffix') + self.parser.add_argument('--isMask', type=int, default=0, help='use face mask') + self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + self.parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data') + self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + self.parser.add_argument('--resize_or_crop', type=str, default='scaleWidth', help='scaling and cropping of images at load time [resize_and_crop|crop|scaledCrop|scaleWidth|scaleWidth_and_crop|scaleWidth_and_scaledCrop|scaleHeight|scaleHeight_and_crop] etc') + self.parser.add_argument('--no_flip', type=int, default=1, help='if specified, do not flip the images for data argumentation') + + + # generator arch + self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') + self.parser.add_argument('--n_downsample_G', type=int, default=8, help='number of downsampling layers in netG') + self.parser.add_argument('--ngf_E', type=int, default=16, help='# of gen filters in first conv layer') + self.parser.add_argument('--n_downsample_E', type=int, default=3, help='number of downsampling layers in Enhancement') + self.parser.add_argument('--n_blocks_E', type=int, default=3, help='number of resnet blocks in Enhancement') + + + # miscellaneous + self.parser.add_argument('--load_pretrain', type=str, default='', help='if specified, load the pretrained model') + self.parser.add_argument('--debug', action='store_true', help='if specified, use small dataset for debug') + self.parser.add_argument('--fp16', type=int, default=0, help='train with AMP') + self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training') + self.parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + + self.initialized = True + + def parse_str(self, ids): + str_ids = ids.split(',') + ids_list = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + ids_list.append(id) + return ids_list + + def parse(self, save=True): + if not self.initialized: + self.initialize() + self.opt, _ = self.parser.parse_known_args() + self.opt.isTrain = self.isTrain # train or test + + self.opt.gpu_ids = self.parse_str(self.opt.gpu_ids) + + # set gpu ids + # if len(self.opt.gpu_ids) > 0: + # torch.cuda.set_device(self.opt.gpu_ids[0]) + + # set datasets + datasets = self.opt.dataset_names.split(',') + self.opt.dataset_names = [] + for name in datasets: + self.opt.dataset_names.append(name) + + if self.isTrain: + self.opt.train_dataset_names = np.loadtxt(os.path.join(self.opt.dataroot, + self.opt.dataset_names[0], + self.opt.train_dataset_names), dtype=np.str).tolist() + if type(self.opt.train_dataset_names) == str: + self.opt.train_dataset_names = [self.opt.train_dataset_names] + self.opt.validate_dataset_names = np.loadtxt(os.path.join(self.opt.dataroot, + self.opt.dataset_names[0], + self.opt.validate_dataset_names), dtype=np.str).tolist() + if type(self.opt.validate_dataset_names) == str: + self.opt.validate_dataset_names = [self.opt.validate_dataset_names] + + else: + test_datasets = self.opt.test_dataset_names.split(',') + self.opt.test_dataset_names = [] + for name in test_datasets: + self.opt.test_dataset_names.append(name) + + + args = vars(self.opt) + + print('------------ Options -------------') + for k, v in sorted(args.items()): + print('%s: %s' % (str(k), str(v))) + print('-------------- End ----------------') + + if self.isTrain: + # save to the disk + expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) + util.mkdirs(expr_dir) + if save: + file_name = os.path.join(expr_dir, 'opt.txt') + with open(file_name, 'wt') as opt_file: + opt_file.write('------------ Options -------------\n') + for k, v in sorted(args.items()): + opt_file.write('%s: %s\n' % (str(k), str(v))) + opt_file.write('-------------- End ----------------\n') + return self.opt diff --git a/LiveSpeechPortraits/source_code/options/test_audio2feature_options.py b/LiveSpeechPortraits/source_code/options/test_audio2feature_options.py new file mode 100644 index 00000000..2478db2d --- /dev/null +++ b/LiveSpeechPortraits/source_code/options/test_audio2feature_options.py @@ -0,0 +1,20 @@ +from .base_options_audio2feature import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--load_epoch', type=str, default='500', help='which epoch to load? set to latest to use latest cached model') + # Dropout and Batchnorm has different behavioir during training and test. + parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + # rewrite devalue values + parser.set_defaults(time_frame_length=1) + self.isTrain = False + + return parser diff --git a/LiveSpeechPortraits/source_code/options/test_audio2headpose_options.py b/LiveSpeechPortraits/source_code/options/test_audio2headpose_options.py new file mode 100644 index 00000000..6e32746f --- /dev/null +++ b/LiveSpeechPortraits/source_code/options/test_audio2headpose_options.py @@ -0,0 +1,20 @@ +from .base_options_audio2headpose import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--load_epoch', type=str, default='500', help='which epoch to load? set to latest to use latest cached model') + # Dropout and Batchnorm has different behavioir during training and test. + parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + # rewrite devalue values + parser.set_defaults(time_frame_length=1) + self.isTrain = False + + return parser \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/options/test_feature2face_options.py b/LiveSpeechPortraits/source_code/options/test_feature2face_options.py new file mode 100644 index 00000000..63e2ea6c --- /dev/null +++ b/LiveSpeechPortraits/source_code/options/test_feature2face_options.py @@ -0,0 +1,11 @@ +from .base_options_feature2face import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self): + BaseOptions.initialize(self) + self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + self.parser.add_argument('--dataset_names', type=str, default='name', help='chooses test datasets.') + self.parser.add_argument('--test_dataset_names', type=str, default='name', help='chooses validation datasets.') + + self.isTrain = False diff --git a/LiveSpeechPortraits/source_code/options/train_audio2feature_options.py b/LiveSpeechPortraits/source_code/options/train_audio2feature_options.py new file mode 100644 index 00000000..0491e38f --- /dev/null +++ b/LiveSpeechPortraits/source_code/options/train_audio2feature_options.py @@ -0,0 +1,56 @@ +from .base_options_audio2feature import BaseOptions + + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + + + # network saving and loading parameters + parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', default=False, action='store_true', help='continue training: load the latest model') + parser.add_argument('--load_epoch', type=str, default='200', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--epoch_count', type=int, default=0, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--re_transform', type=int, default=0, help='re-transform landmarks') + + + # training parameters + parser.add_argument('--train_dataset_names', type=str, default='train_list.txt', help='chooses validation datasets.') + parser.add_argument('--validate_dataset_names', type=str, default='val_list.txt', help='chooses validation datasets.') + parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]') + parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate for adam') + parser.add_argument('--gamma', type=float, default=0.2, help='step learning rate gamma') + parser.add_argument('--lr_decay_iters', type=int, default=250, help='multiply by a gamma every lr_decay_iters iterations') + parser.add_argument('--n_epochs_decay', type=int, default=250, help='number of epochs to linearly decay learning rate to zero') + parser.add_argument('--validate_epoch', type=int, default=50, help='validate model every some epochs, 0 for not validate during training') + parser.add_argument('--loss_smooth_weight', type=float, default=0, help='smooth loss weight, 0 for not use smooth loss') + parser.add_argument('--optimizer', type=str, default='AdamW', help='Adam, AdamW, RMSprop') + + + # data augmentations + parser.add_argument('--gaussian_noise', type=int, default=1, help='whether add gaussian noise to input & groundtruth features') + parser.add_argument('--gaussian_noise_scale', type=float, default=0.01, help='gaussian noise scale') + + + self.isTrain = True + return parser + + + + + + + + + + + + diff --git a/LiveSpeechPortraits/source_code/options/train_audio2headpose_options.py b/LiveSpeechPortraits/source_code/options/train_audio2headpose_options.py new file mode 100644 index 00000000..27339c81 --- /dev/null +++ b/LiveSpeechPortraits/source_code/options/train_audio2headpose_options.py @@ -0,0 +1,45 @@ +from .base_options_audio2headpose import BaseOptions + + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + + + # network saving and loading parameters + parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', default=False, action='store_true', help='continue training: load the latest model') + parser.add_argument('--load_epoch', type=str, default='0', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--epoch_count', type=int, default=0, help='the starting epoch count, we save the model by , +, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--re_transform', type=int, default=0, help='re-transform landmarks') + + + # training parameters + parser.add_argument('--smooth_loss', type=int, default=0, help='use smooth loss weight, 0 for not use') + parser.add_argument('--train_dataset_names', type=str, default='train_list.txt', help='chooses validation datasets.') + parser.add_argument('--validate_dataset_names', type=str, default='val_list.txt', help='chooses validation datasets.') + parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]') + parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate for adam') + parser.add_argument('--gamma', type=float, default=0.2, help='step learning rate gamma') + parser.add_argument('--lr_decay_iters', type=int, default=250, help='multiply by a gamma every lr_decay_iters iterations') + parser.add_argument('--n_epochs_decay', type=int, default=250, help='number of epochs to linearly decay learning rate to zero') + parser.add_argument('--validate_epoch', type=int, default=50, help='validate model every some epochs, 0 for not validate during training') + parser.add_argument('--loss_smooth_weight', type=float, default=0, help='smooth loss weight, 0 for not use smooth loss') + parser.add_argument('--optimizer', type=str, default='AdamW', help='Adam, AdamW, RMSprop') + + + # data augmentations + parser.add_argument('--gaussian_noise', type=int, default=1, help='whether add gaussian noise to input & groundtruth features') + parser.add_argument('--gaussian_noise_scale', type=float, default=0.01, help='gaussian noise scale') + + + self.isTrain = True + return parser \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/options/train_feature2face_options.py b/LiveSpeechPortraits/source_code/options/train_feature2face_options.py new file mode 100644 index 00000000..25516c8d --- /dev/null +++ b/LiveSpeechPortraits/source_code/options/train_feature2face_options.py @@ -0,0 +1,63 @@ +from .base_options_feature2face import BaseOptions + + +class TrainOptions(BaseOptions): + def initialize(self): + BaseOptions.initialize(self) + ## dataset settings + self.parser.add_argument('--dataset_names', type=str, default='name', help='chooses how datasets are loaded.') + self.parser.add_argument('--train_dataset_names', type=str, default='train_list.txt') + self.parser.add_argument('--validate_dataset_names', type=str, default='val_list.txt') + + + ## training flags + self.parser.add_argument('--display_freq', type=int, default=10, help='frequency of showing training results on screen(iterations)') + self.parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console(epochs)') + self.parser.add_argument('--save_latest_freq', type=int, default=100, help='frequency of to save the latest results(iterations)') + self.parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + self.parser.add_argument('--continue_train', default=True, action='store_true', help='continue training: load the latest model') + self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + self.parser.add_argument('--load_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + self.parser.add_argument('--n_epochs_warm_up', type=int, default=5, help='number of epochs warm up') + self.parser.add_argument('--n_epochs', type=int, default=10, help='number of epochs') + self.parser.add_argument('--n_epochs_decay', type=int, default=10, help='number of epochs to linearly decay learning rate to zero') + self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + self.parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') + self.parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') + self.parser.add_argument('--lr_decay_iters', type=int, default=900, help='multiply by a gamma every lr_decay_iters iterations') + self.parser.add_argument('--lr_decay_gamma', type=float, default=0.25, help='multiply by a gamma every lr_decay_iters iterations') + self.parser.add_argument('--TTUR', action='store_true', help='Use TTUR training scheme') + self.parser.add_argument('--gan_mode', type=str, default='ls', help='(ls|original|hinge)') + self.parser.add_argument('--pool_size', type=int, default=1, help='the size of image buffer that stores previously generated images') + self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + self.parser.add_argument('--frame_jump', type=int, default=1, help='jump frame for training, 1 for not jump') + self.parser.add_argument('--epoch_count', type=int, default=0, help='the starting epoch count, we save the model by , +, ...') + self.parser.add_argument('--seq_max_len', type=int, default=120, help='maximum sequence clip frames sent to network per iteration') + + + # for discriminators + self.parser.add_argument('--no_discriminator', type=int, default=0, help='not use discriminator') + self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') + self.parser.add_argument('--num_D', type=int, default=2, help='number of patch scales in each discriminator') + self.parser.add_argument('--n_layers_D', type=int, default=3, help='number of layers in discriminator') + self.parser.add_argument('--no_vgg', action='store_true', help='do not use VGG feature matching loss') + self.parser.add_argument('--no_ganFeat', action='store_true', help='do not match discriminator features') + self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching') + self.parser.add_argument('--sparse_D', action='store_true', help='use sparse temporal discriminators to save memory') + + + # for temporal + self.parser.add_argument('--lambda_T', type=float, default=10.0, help='weight for temporal loss') + self.parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for temporal loss') + self.parser.add_argument('--lambda_F', type=float, default=10.0, help='weight for flow loss') + self.parser.add_argument('--lambda_mask', type=float, default=500.0, help='weight for mask l1 loss') + self.parser.add_argument('--n_frames_D', type=int, default=3, help='number of frames to feed into temporal discriminator') + self.parser.add_argument('--n_scales_temporal', type=int, default=2, help='number of temporal scales in the temporal discriminator') + self.parser.add_argument('--n_frames_per_gpu', type=int, default=1, help='the number of frames to load into one GPU at a time. only 1 is supported now') + self.parser.add_argument('--max_frames_backpropagate', type=int, default=1, help='max number of frames to backpropagate') + self.parser.add_argument('--max_t_step', type=int, default=1, help='max spacing between neighboring sampled frames. If greater than 1, the network may randomly skip frames during training.') + self.parser.add_argument('--n_frames_total', type=int, default=12, help='the overall number of frames in a sequence to train with') + self.parser.add_argument('--nepochs_step', type=int, default=5, help='how many epochs do we change training sequence length again') + self.parser.add_argument('--nepochs_fix_global', type=int, default=0, help='if specified, only train the finest spatial layer for the given iterations') + + self.isTrain = True diff --git a/LiveSpeechPortraits/source_code/predict.py b/LiveSpeechPortraits/source_code/predict.py new file mode 100644 index 00000000..084b8dca --- /dev/null +++ b/LiveSpeechPortraits/source_code/predict.py @@ -0,0 +1,308 @@ +import os +import subprocess +from os.path import join +import yaml +import tempfile +import argparse +from skimage.io import imread +import numpy as np +import librosa +from util import util +from tqdm import tqdm +import torch +from collections import OrderedDict +import cv2 +from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip +from cog import BasePredictor, Input, Path +import scipy.io as sio +import albumentations as A +from options.test_audio2feature_options import TestOptions as FeatureOptions +from options.test_audio2headpose_options import TestOptions as HeadposeOptions +from options.test_feature2face_options import TestOptions as RenderOptions +from datasets import create_dataset +from models import create_model +from models.networks import APC_encoder +from util.visualizer import Visualizer +from funcs import utils, audio_funcs +from demo import write_video_with_audio +import warnings + +warnings.filterwarnings("ignore") + + +class Predictor(BasePredictor): + def setup(self): + self.parser = argparse.ArgumentParser() + self.parser.add_argument('--id', default='May', help="person name, e.g. Obama1, Obama2, May, Nadella, McStay") + self.parser.add_argument('--driving_audio', default='data/Input/00083.wav', help="path to driving audio") + self.parser.add_argument('--save_intermediates', default=0, help="whether to save intermediate results") + + def predict(self, + driving_audio: Path = Input(description='driving audio, if the file is more than 20 seconds, only the first 20 seconds will be processed for video generation'), + talking_head: str = Input(description="choose a talking head", choices=['May', 'Obama1', 'Obama2', 'Nadella', 'McStay'], default='May') + ) -> Path: + + ############################### I/O Settings ############################## + # load config files + opt = self.parser.parse_args('') + opt.driving_audio = str(driving_audio) + opt.id = talking_head + with open(join('config', opt.id + '.yaml')) as f: + config = yaml.safe_load(f) + data_root = join('data', opt.id) + + ############################ Hyper Parameters ############################# + h, w, sr, FPS = 512, 512, 16000, 60 + mouth_indices = np.concatenate([np.arange(4, 11), np.arange(46, 64)]) + eye_brow_indices = [27, 65, 28, 68, 29, 67, 30, 66, 31, 72, 32, 69, 33, 70, 34, 71] + eye_brow_indices = np.array(eye_brow_indices, np.int32) + + ############################ Pre-defined Data ############################# + mean_pts3d = np.load(join(data_root, 'mean_pts3d.npy')) + fit_data = np.load(config['dataset_params']['fit_data_path']) + pts3d = np.load(config['dataset_params']['pts3d_path']) - mean_pts3d + trans = fit_data['trans'][:, :, 0].astype(np.float32) + mean_translation = trans.mean(axis=0) + candidate_eye_brow = pts3d[10:, eye_brow_indices] + std_mean_pts3d = np.load(config['dataset_params']['pts3d_path']).mean(axis=0) + # candidates images + img_candidates = [] + for j in range(4): + output = imread(join(data_root, 'candidates', f'normalized_full_{j}.jpg')) + output = A.pytorch.transforms.ToTensor(normalize={'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5)})(image=output)['image'] + img_candidates.append(output) + img_candidates = torch.cat(img_candidates).unsqueeze(0).cuda() + + # shoulders + shoulders = np.load(join(data_root, 'normalized_shoulder_points.npy')) + shoulder3D = np.load(join(data_root, 'shoulder_points3D.npy'))[1] + ref_trans = trans[1] + + # camera matrix, we always use training set intrinsic parameters. + camera = utils.camera() + camera_intrinsic = np.load(join(data_root, 'camera_intrinsic.npy')).astype(np.float32) + APC_feat_database = np.load(join(data_root, 'APC_feature_base.npy')) + + # load reconstruction data + scale = sio.loadmat(join(data_root, 'id_scale.mat'))['scale'][0, 0] + Audio2Mel_torch = audio_funcs.Audio2Mel(n_fft=512, hop_length=int(16000 / 120), win_length=int(16000 / 60), + sampling_rate=16000, + n_mel_channels=80, mel_fmin=90, mel_fmax=7600.0).cuda() + + ########################### Experiment Settings ########################### + #### user config + use_LLE = config['model_params']['APC']['use_LLE'] + Knear = config['model_params']['APC']['Knear'] + LLE_percent = config['model_params']['APC']['LLE_percent'] + headpose_sigma = config['model_params']['Headpose']['sigma'] + Feat_smooth_sigma = config['model_params']['Audio2Mouth']['smooth'] + Head_smooth_sigma = config['model_params']['Headpose']['smooth'] + Feat_center_smooth_sigma, Head_center_smooth_sigma = 0, 0 + AMP_method = config['model_params']['Audio2Mouth']['AMP'][0] + Feat_AMPs = config['model_params']['Audio2Mouth']['AMP'][1:] + rot_AMP, trans_AMP = config['model_params']['Headpose']['AMP'] + shoulder_AMP = config['model_params']['Headpose']['shoulder_AMP'] + save_feature_maps = config['model_params']['Image2Image']['save_input'] + + #### common settings + Featopt = FeatureOptions().parse() + Headopt = HeadposeOptions().parse() + Renderopt = RenderOptions().parse() + Featopt.load_epoch = config['model_params']['Audio2Mouth']['ckp_path'] + Headopt.load_epoch = config['model_params']['Headpose']['ckp_path'] + Renderopt.dataroot = config['dataset_params']['root'] + Renderopt.load_epoch = config['model_params']['Image2Image']['ckp_path'] + Renderopt.size = config['model_params']['Image2Image']['size'] + + ############################# Load Models ################################# + print('---------- Loading Model: APC-------------') + APC_model = APC_encoder(config['model_params']['APC']['mel_dim'], + config['model_params']['APC']['hidden_size'], + config['model_params']['APC']['num_layers'], + config['model_params']['APC']['residual']) + # load all 5 here? + APC_model.load_state_dict(torch.load(config['model_params']['APC']['ckp_path']), strict=False) + APC_model.cuda() + APC_model.eval() + print('---------- Loading Model: {} -------------'.format(Featopt.task)) + Audio2Feature = create_model(Featopt) + Audio2Feature.setup(Featopt) + Audio2Feature.eval() + print('---------- Loading Model: {} -------------'.format(Headopt.task)) + Audio2Headpose = create_model(Headopt) + Audio2Headpose.setup(Headopt) + Audio2Headpose.eval() + if Headopt.feature_decoder == 'WaveNet': + Headopt.A2H_receptive_field = Audio2Headpose.Audio2Headpose.module.WaveNet.receptive_field + print('---------- Loading Model: {} -------------'.format(Renderopt.task)) + facedataset = create_dataset(Renderopt) + Feature2Face = create_model(Renderopt) + Feature2Face.setup(Renderopt) + Feature2Face.eval() + visualizer = Visualizer(Renderopt) + + # check audio duration and trim audio + extension_name = os.path.basename(opt.driving_audio).split('.')[-1] + audio_threshold = 10 + duration = librosa.get_duration(filename=opt.driving_audio) + if duration > audio_threshold: + print(f'audio file is longer than {audio_threshold} seconds, trimming the first {audio_threshold} seconds ' + f'for further processing') + ffmpeg_extract_subclip(opt.driving_audio, 0, audio_threshold, targetname=f'shorter_input.{extension_name}') + opt.driving_audio = f'shorter_input.{extension_name}' + + # create the results folder + audio_name = os.path.basename(opt.driving_audio).split('.')[0] + save_root = join('results', opt.id, audio_name) + os.makedirs(save_root, exist_ok=True) + clean_folder(save_root) + out_path = Path(tempfile.mkdtemp()) / "out.mp4" + + ############################## Inference ################################## + print('Processing audio: {} ...'.format(audio_name)) + # read audio + audio, _ = librosa.load(opt.driving_audio, sr=sr) + total_frames = np.int32(audio.shape[0] / sr * FPS) + + #### 1. compute APC features + print('1. Computing APC features...') + mel80 = utils.compute_mel_one_sequence(audio) + mel_nframe = mel80.shape[0] + with torch.no_grad(): + length = torch.Tensor([mel_nframe]) + mel80_torch = torch.from_numpy(mel80.astype(np.float32)).cuda().unsqueeze(0) + hidden_reps = APC_model.forward(mel80_torch, length)[0] # [mel_nframe, 512] + hidden_reps = hidden_reps.cpu().numpy() + audio_feats = hidden_reps + + #### 2. manifold projection + if use_LLE: + print('2. Manifold projection...') + ind = utils.KNN_with_torch(audio_feats, APC_feat_database, K=Knear) + weights, feat_fuse = utils.compute_LLE_projection_all_frame(audio_feats, APC_feat_database, ind, + audio_feats.shape[0]) + audio_feats = audio_feats * (1 - LLE_percent) + feat_fuse * LLE_percent + + #### 3. Audio2Mouth + print('3. Audio2Mouth inference...') + pred_Feat = Audio2Feature.generate_sequences(audio_feats, sr, FPS, fill_zero=True, opt=Featopt) + + #### 4. Audio2Headpose + print('4. Headpose inference...') + # set history headposes as zero + pre_headpose = np.zeros(Headopt.A2H_wavenet_input_channels, np.float32) + pred_Head = Audio2Headpose.generate_sequences(audio_feats, pre_headpose, fill_zero=True, sigma_scale=0.3, + opt=Headopt) + + #### 5. Post-Processing + print('5. Post-processing...') + nframe = min(pred_Feat.shape[0], pred_Head.shape[0]) + pred_pts3d = np.zeros([nframe, 73, 3]) + pred_pts3d[:, mouth_indices] = pred_Feat.reshape(-1, 25, 3)[:nframe] + + ## mouth + pred_pts3d = utils.landmark_smooth_3d(pred_pts3d, Feat_smooth_sigma, area='only_mouth') + pred_pts3d = utils.mouth_pts_AMP(pred_pts3d, True, AMP_method, Feat_AMPs) + pred_pts3d = pred_pts3d + mean_pts3d + pred_pts3d = utils.solve_intersect_mouth(pred_pts3d) # solve intersect lips if exist + + ## headpose + pred_Head[:, 0:3] *= rot_AMP + pred_Head[:, 3:6] *= trans_AMP + pred_headpose = utils.headpose_smooth(pred_Head[:, :6], Head_smooth_sigma).astype(np.float32) + pred_headpose[:, 3:] += mean_translation + pred_headpose[:, 0] += 180 + + ## compute projected landmarks + pred_landmarks = np.zeros([nframe, 73, 2], dtype=np.float32) + final_pts3d = np.zeros([nframe, 73, 3], dtype=np.float32) + final_pts3d[:] = std_mean_pts3d.copy() + final_pts3d[:, 46:64] = pred_pts3d[:nframe, 46:64] + for k in tqdm(range(nframe)): + ind = k % candidate_eye_brow.shape[0] + final_pts3d[k, eye_brow_indices] = candidate_eye_brow[ind] + mean_pts3d[eye_brow_indices] + pred_landmarks[k], _, _ = utils.project_landmarks(camera_intrinsic, camera.relative_rotation, + camera.relative_translation, scale, + pred_headpose[k], final_pts3d[k]) + + ## Upper Body Motion + pred_shoulders = np.zeros([nframe, 18, 2], dtype=np.float32) + pred_shoulders3D = np.zeros([nframe, 18, 3], dtype=np.float32) + for k in range(nframe): + diff_trans = pred_headpose[k][3:] - ref_trans + pred_shoulders3D[k] = shoulder3D + diff_trans * shoulder_AMP + # project + project = camera_intrinsic.dot(pred_shoulders3D[k].T) + project[:2, :] /= project[2, :] # divide z + pred_shoulders[k] = project[:2, :].T + + #### 6. Image2Image translation & Save resuls + print('6. Image2Image translation & Saving results...') + for ind in tqdm(range(0, nframe), desc='Image2Image translation inference'): + # feature_map: [input_nc, h, w] + current_pred_feature_map = facedataset.dataset.get_data_test_mode(pred_landmarks[ind], + pred_shoulders[ind], + facedataset.dataset.image_pad) + input_feature_maps = current_pred_feature_map.unsqueeze(0).cuda() + pred_fake = Feature2Face.inference(input_feature_maps, img_candidates) + # save results + visual_list = [('pred', util.tensor2im(pred_fake[0]))] + if save_feature_maps: + visual_list += [('input', np.uint8(current_pred_feature_map[0].cpu().numpy() * 255))] + visuals = OrderedDict(visual_list) + visualizer.save_images(save_root, visuals, str(ind + 1)) + + ## make videos + # generate corresponding audio, reused for all results + tmp_audio_path = join(save_root, 'tmp.wav') + tmp_audio_clip = audio[: np.int32(nframe * sr / FPS)] + librosa.output.write_wav(tmp_audio_path, tmp_audio_clip, sr) + + def write_video_with_audio(audio_path, output_path, prefix='pred_'): + fps, fourcc = 60, cv2.VideoWriter_fourcc(*'DIVX') + video_tmp_path = join(save_root, 'tmp.avi') + out = cv2.VideoWriter(video_tmp_path, fourcc, fps, (Renderopt.loadSize, Renderopt.loadSize)) + for j in tqdm(range(nframe), position=0, desc='writing video'): + img = cv2.imread(join(save_root, prefix + str(j + 1) + '.jpg')) + out.write(img) + out.release() + cmd = 'ffmpeg -i "' + video_tmp_path + '" -i "' + audio_path + '" -codec copy -shortest "' + output_path + '"' + subprocess.call(cmd, shell=True) + os.remove(video_tmp_path) # remove the template video + + temp_out = 'temp_video.avi' + write_video_with_audio(tmp_audio_path, temp_out, 'pred_') + # convert to mp4 + cmd = ("ffmpeg -i " + + temp_out + " -strict -2 " + + str(out_path) + ) + subprocess.call(cmd, shell=True) + + if os.path.exists(tmp_audio_path): + os.remove(tmp_audio_path) + if os.path.exists(temp_out): + os.remove(temp_out) + if os.path.exists(f'shorter_input.{extension_name}'): + os.remove(f'shorter_input.{extension_name}') + if not opt.save_intermediates: + _img_paths = list(map(lambda x: str(x), list(Path(save_root).glob('*.jpg')))) + for i in tqdm(range(len(_img_paths)), desc='deleting intermediate images'): + os.remove(_img_paths[i]) + + print('Finish!') + + return out_path + + +def clean_folder(folder): + for filename in os.listdir(folder): + file_path = os.path.join(folder, filename) + try: + if os.path.isfile(file_path) or os.path.islink(file_path): + os.unlink(file_path) + elif os.path.isdir(file_path): + shutil.rmtree(file_path) + except Exception as e: + print('Failed to delete %s. Reason: %s' % (file_path, e)) \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/requirements.txt b/LiveSpeechPortraits/source_code/requirements.txt new file mode 100644 index 00000000..5c988b64 --- /dev/null +++ b/LiveSpeechPortraits/source_code/requirements.txt @@ -0,0 +1,9 @@ +tqdm +librosa==0.7.0 +scikit_image +opencv_python==4.4.0.40 +scipy +dominate +albumentations==0.5.2 +numpy +beautifulsoup4 \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/util/flow_viz.py b/LiveSpeechPortraits/source_code/util/flow_viz.py new file mode 100644 index 00000000..dcee65e8 --- /dev/null +++ b/LiveSpeechPortraits/source_code/util/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/LiveSpeechPortraits/source_code/util/get_data.py b/LiveSpeechPortraits/source_code/util/get_data.py new file mode 100644 index 00000000..97edc3ce --- /dev/null +++ b/LiveSpeechPortraits/source_code/util/get_data.py @@ -0,0 +1,110 @@ +from __future__ import print_function +import os +import tarfile +import requests +from warnings import warn +from zipfile import ZipFile +from bs4 import BeautifulSoup +from os.path import abspath, isdir, join, basename + + +class GetData(object): + """A Python script for downloading CycleGAN or pix2pix datasets. + + Parameters: + technique (str) -- One of: 'cyclegan' or 'pix2pix'. + verbose (bool) -- If True, print additional information. + + Examples: + >>> from util.get_data import GetData + >>> gd = GetData(technique='cyclegan') + >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. + + Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' + and 'scripts/download_cyclegan_model.sh'. + """ + + def __init__(self, technique='cyclegan', verbose=True): + url_dict = { + 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', + 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' + } + self.url = url_dict.get(technique.lower()) + self._verbose = verbose + + def _print(self, text): + if self._verbose: + print(text) + + @staticmethod + def _get_options(r): + soup = BeautifulSoup(r.text, 'lxml') + options = [h.text for h in soup.find_all('a', href=True) + if h.text.endswith(('.zip', 'tar.gz'))] + return options + + def _present_options(self): + r = requests.get(self.url) + options = self._get_options(r) + print('Options:\n') + for i, o in enumerate(options): + print("{0}: {1}".format(i, o)) + choice = input("\nPlease enter the number of the " + "dataset above you wish to download:") + return options[int(choice)] + + def _download_data(self, dataset_url, save_path): + if not isdir(save_path): + os.makedirs(save_path) + + base = basename(dataset_url) + temp_save_path = join(save_path, base) + + with open(temp_save_path, "wb") as f: + r = requests.get(dataset_url) + f.write(r.content) + + if base.endswith('.tar.gz'): + obj = tarfile.open(temp_save_path) + elif base.endswith('.zip'): + obj = ZipFile(temp_save_path, 'r') + else: + raise ValueError("Unknown File Type: {0}.".format(base)) + + self._print("Unpacking Data...") + obj.extractall(save_path) + obj.close() + os.remove(temp_save_path) + + def get(self, save_path, dataset=None): + """ + + Download a dataset. + + Parameters: + save_path (str) -- A directory to save the data to. + dataset (str) -- (optional). A specific dataset to download. + Note: this must include the file extension. + If None, options will be presented for you + to choose from. + + Returns: + save_path_full (str) -- the absolute path to the downloaded data. + + """ + if dataset is None: + selected_dataset = self._present_options() + else: + selected_dataset = dataset + + save_path_full = join(save_path, selected_dataset.split('.')[0]) + + if isdir(save_path_full): + warn("\n'{0}' already exists. Voiding Download.".format( + save_path_full)) + else: + self._print('Downloading Data...') + url = "{0}/{1}".format(self.url, selected_dataset) + self._download_data(url, save_path=save_path) + + return abspath(save_path_full) diff --git a/LiveSpeechPortraits/source_code/util/html.py b/LiveSpeechPortraits/source_code/util/html.py new file mode 100644 index 00000000..10f2fbdc --- /dev/null +++ b/LiveSpeechPortraits/source_code/util/html.py @@ -0,0 +1,67 @@ +import dominate +from dominate.tags import * +import os + + +class HTML: + def __init__(self, web_dir, title, reflesh=0): + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + # print(self.img_dir) + + self.doc = dominate.document(title=title) + if reflesh > 0: + with self.doc.head: + meta(http_equiv="reflesh", content=str(reflesh)) + + def get_image_dir(self): + return self.img_dir + + def add_header(self, str): + with self.doc: + h3(str) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_images(self, ims, txts, links, width=400, height=0): + self.add_table() + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + if height != 0: + img(style="width:%dpx;height:%dpx" % (width, height), src=os.path.join('images', im)) + else: + img(style="width:%dpx" % (width), src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + html_file = '%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.jpg' % n) + txts.append('text_%d' % n) + links.append('image_%d.jpg' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/LiveSpeechPortraits/source_code/util/image_pool.py b/LiveSpeechPortraits/source_code/util/image_pool.py new file mode 100644 index 00000000..152ef5be --- /dev/null +++ b/LiveSpeechPortraits/source_code/util/image_pool.py @@ -0,0 +1,32 @@ +import random +import numpy as np +import torch +from torch.autograd import Variable +class ImagePool(): + def __init__(self, pool_size): + self.pool_size = pool_size + if self.pool_size > 0: + self.num_imgs = 0 + self.images = [] + + def query(self, images): + if self.pool_size == 0: + return images + return_images = [] + for image in images.data: + image = torch.unsqueeze(image, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: + random_id = random.randint(0, self.pool_size-1) + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: + return_images.append(image) + return_images = Variable(torch.cat(return_images, 0)) + return return_images diff --git a/LiveSpeechPortraits/source_code/util/util.py b/LiveSpeechPortraits/source_code/util/util.py new file mode 100644 index 00000000..6bd1dabb --- /dev/null +++ b/LiveSpeechPortraits/source_code/util/util.py @@ -0,0 +1,93 @@ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import inspect, re +import numpy as np +import os +import collections +from PIL import Image +import cv2 +from collections import OrderedDict + +from . import flow_viz + + + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8, normalize=True): + if isinstance(image_tensor, list): + image_numpy = [] + for i in range(len(image_tensor)): + image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) + return image_numpy + + if isinstance(image_tensor, torch.autograd.Variable): + image_tensor = image_tensor.data + if len(image_tensor.size()) == 5: + image_tensor = image_tensor[0, -1] + if len(image_tensor.size()) == 4: + image_tensor = image_tensor[0] + image_tensor = image_tensor[:3] + image_numpy = image_tensor.cpu().float().numpy() + if normalize: + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + else: + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 + #image_numpy = (np.transpose(image_numpy, (1, 2, 0)) * std + mean) * 255.0 + image_numpy = np.clip(image_numpy, 0, 255) + if image_numpy.shape[2] == 1: + image_numpy = image_numpy[:,:,0] + return image_numpy.astype(imtype) + + +def tensor2flow(flo, imtype=np.uint8): + flo = flo[0].permute(1,2,0).cpu().detach().numpy() + flo = flow_viz.flow_to_image(flo) + return flo + + +def add_dummy_to_tensor(tensors, add_size=0): + if add_size == 0 or tensors is None: return tensors + if isinstance(tensors, list): + return [add_dummy_to_tensor(tensor, add_size) for tensor in tensors] + + if isinstance(tensors, torch.Tensor): + dummy = torch.zeros_like(tensors)[:add_size] + tensors = torch.cat([dummy, tensors]) + return tensors + +def remove_dummy_from_tensor(tensors, remove_size=0): + if remove_size == 0 or tensors is None: return tensors + if isinstance(tensors, list): + return [remove_dummy_from_tensor(tensor, remove_size) for tensor in tensors] + + if isinstance(tensors, torch.Tensor): + tensors = tensors[remove_size:] + return tensors + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + image_pil.save(image_path) + +def print_numpy(x, val=True, shp=False): + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + diff --git a/LiveSpeechPortraits/source_code/util/visualizer.py b/LiveSpeechPortraits/source_code/util/visualizer.py new file mode 100644 index 00000000..fd25a6eb --- /dev/null +++ b/LiveSpeechPortraits/source_code/util/visualizer.py @@ -0,0 +1,149 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import numpy as np +import os +import time +from . import util +from . import html +import scipy.misc +try: + from StringIO import StringIO # Python 2.7 +except ImportError: + from io import BytesIO # Python 3.x + +class Visualizer(): + def __init__(self, opt): + self.opt = opt + self.tf_log = opt.tf_log + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + if opt.isTrain: + if self.tf_log: + from torch.utils.tensorboard import SummaryWriter + # import tensorflow as tf + # self.tf = tf + self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') + # self.writer = tf.summary.FileWriter(self.log_dir) + self.writer = SummaryWriter(self.log_dir, flush_secs=1) + + if self.use_html: + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' % self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch, step): +# if self.tf_log: # show images in tensorboard output +# img_summaries = [] +# for label, image_numpy in visuals.items(): +# # Write the image to a string +# try: +# s = StringIO() +# except: +# s = BytesIO() +# scipy.misc.toimage(image_numpy).save(s, format="jpeg") +# # Create an Image object +# img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) +# # Create a Summary value +# img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) +# +# # Create and write Summary +# summary = self.tf.Summary(value=img_summaries) +# self.writer.add_summary(summary, step) + + if self.use_html: # save images to a html file + for label, image_numpy in visuals.items(): + if isinstance(image_numpy, list): + for i in range(len(image_numpy)): + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)) + util.save_image(image_numpy[i], img_path) + else: + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + if isinstance(image_numpy, list): + for i in range(len(image_numpy)): + img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i) + ims.append(img_path) + txts.append(label+str(i)) + links.append(img_path) + else: + img_path = 'epoch%.3d_%s.jpg' % (n, label) + ims.append(img_path) + txts.append(label) + links.append(img_path) + if len(ims) < 5: + webpage.add_images(ims, txts, links, width=self.win_size) + else: + num = int(round(len(ims)/2.0)) + webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) + webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) + webpage.save() + + # errors: dictionary of error labels and values + def plot_current_errors(self, errors, step): + if self.tf_log: + for tag, value in errors.items(): +# summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) +# self.writer.add_summary(summary, step) + self.writer.add_scalar(tag, value, step) + + + # errors: same format as |errors| of plotCurrentErrors + def print_current_errors(self, epoch, i, errors, t): + message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) + for k, v in sorted(errors.items()): + if v != 0: + message += '%s: %.3f ' % (k, v) + + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) + + # save image to the disk + def save_images(self, image_dir, visuals, image_path, webpage=None): + dirname = os.path.basename(os.path.dirname(image_path[0])) + image_dir = os.path.join(image_dir, dirname) + util.mkdir(image_dir) + name = image_path +# name = os.path.basename(image_path[0]) +# name = os.path.splitext(name)[0] + + if webpage is not None: + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + save_ext = 'jpg' + image_name = '%s_%s.%s' % (label, name, save_ext) + save_path = os.path.join(image_dir, image_name) + util.save_image(image_numpy, save_path) + + if webpage is not None: + ims.append(image_name) + txts.append(label) + links.append(image_name) + if webpage is not None: + webpage.add_images(ims, txts, links, width=self.win_size) + + def vis_print(self, message): + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) + diff --git a/LiveSpeechPortraits/source_code/video_eval.py b/LiveSpeechPortraits/source_code/video_eval.py new file mode 100644 index 00000000..df953ce2 --- /dev/null +++ b/LiveSpeechPortraits/source_code/video_eval.py @@ -0,0 +1,152 @@ +import cv2 +import os +import numpy as np +import torch +from tqdm import tqdm +import math +from scipy.linalg import sqrtm +from PIL import Image +from torchvision import models, transforms +from pytorch_fid import fid_score # 新增引入pytorch-fid + +# 获取视频信息:帧数与持续时间 +def get_video_info(video_path): + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise Exception(f"Cannot open video: {video_path}") + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = cap.get(cv2.CAP_PROP_FPS) + duration = frame_count / fps if fps > 0 else 0 + cap.release() + return frame_count, duration + +# 根据目标FPS进行视频帧率调整 +def adjust_fps(input_video, output_video, target_fps, target_frame_count=None): + print(f"Adjusting FPS for {input_video} to {target_fps} FPS...") + cap = cv2.VideoCapture(input_video) + if not cap.isOpened(): + raise Exception(f"Cannot open video: {input_video}") + + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + original_fps = cap.get(cv2.CAP_PROP_FPS) + fourcc = cv2.VideoWriter_fourcc(*'XVID') + + out = cv2.VideoWriter(output_video, fourcc, target_fps, (width, height)) + frame_interval = int(original_fps / target_fps) if target_fps < original_fps else 1 + + frame_count = 0 + extracted_frame_count = 0 + last_frame = None + + while True: + ret, frame = cap.read() + if not ret: + print(f"Warning: Failed to read frame {frame_count}. Using last valid frame.") + if last_frame is not None and (target_frame_count is None or extracted_frame_count < target_frame_count): + out.write(last_frame) + extracted_frame_count += 1 + if target_frame_count is not None and extracted_frame_count >= target_frame_count: + break + frame_count += 1 + continue + + if frame_count % frame_interval == 0: + out.write(frame) + extracted_frame_count += 1 + last_frame = frame + + if target_frame_count is not None and extracted_frame_count >= target_frame_count: + break + + frame_count += 1 + + # 如果需要,使用最后一帧填补 + while target_frame_count is not None and extracted_frame_count < target_frame_count: + if last_frame is not None: + out.write(last_frame) + extracted_frame_count += 1 + else: + print("Error: No valid frames available to pad the output.") + + cap.release() + out.release() + print(f"FPS adjustment completed: {input_video} -> {output_video}. Extracted {extracted_frame_count} frames.") + return extracted_frame_count + +# 将视频转为帧 +def video_to_frames(video_path, output_dir, target_frame_count=None): + if os.path.exists(output_dir) and os.listdir(output_dir): + print(f"Frames already exist in {output_dir}, skipping extraction.") + return len(os.listdir(output_dir)) + + print(f"Extracting frames from {video_path}...") + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise Exception(f"Cannot open video: {video_path}") + os.makedirs(output_dir, exist_ok=True) + frame_idx = 0 + extracted_frame_count = 0 + while True: + ret, frame = cap.read() + if not ret: + break + frame_path = os.path.join(output_dir, f"frame_{frame_idx:04d}.png") + cv2.imwrite(frame_path, frame) + extracted_frame_count += 1 + frame_idx += 1 + if target_frame_count is not None and extracted_frame_count >= target_frame_count: + break + cap.release() + print(f"Frame extraction completed. Total frames: {extracted_frame_count}") + return extracted_frame_count + +# 这里原本有InceptionV3FeatureExtractor和preprocess_image函数,现在不再需要。 + +# 使用pytorch-fid计算FID +def compute_fid(gt_frames_dir, gen_frames_dir): + # 使用pytorch_fid提供的calculate_fid_given_paths计算FID + fid = fid_score.calculate_fid_given_paths( + [gt_frames_dir, gen_frames_dir], + batch_size=50, # 可根据情况调整 + device='cuda:0' if torch.cuda.is_available() else 'cpu', + dims=2048 # 与Inception V3特征维度一致 + ) + return fid + +# 主流程,只保留FID计算 +def compute_fid_for_videos(gt_video, gen_video, output_dir, target_fps=30): + # 获取视频信息 + gt_frame_count, gt_duration = get_video_info(gt_video) + gen_frame_count, gen_duration = get_video_info(gen_video) + + target_frame_count = min(gt_frame_count, gen_frame_count) + target_fps_gt = target_frame_count / gt_duration if gt_duration > 0 else target_fps + target_fps_gen = target_frame_count / gen_duration if gen_duration > 0 else target_fps + + # 调整视频帧数和FPS + adjusted_gt_video = os.path.join(output_dir, "adjusted_gt_video.avi") + gt_frame_count = adjust_fps(gt_video, adjusted_gt_video, target_fps_gt, target_frame_count) + + adjusted_gen_video = os.path.join(output_dir, "adjusted_gen_video.avi") + gen_frame_count = adjust_fps(gen_video, adjusted_gen_video, target_fps_gen, target_frame_count) + + # 提取帧 + gt_frames_dir = os.path.join(output_dir, "ground_truth") + gt_frame_count = video_to_frames(adjusted_gt_video, gt_frames_dir, target_frame_count) + + gen_frames_dir = os.path.join(output_dir, "generated") + gen_frame_count = video_to_frames(adjusted_gen_video, gen_frames_dir, target_frame_count) + + # 使用pytorch_fid计算FID + fid = compute_fid(gt_frames_dir, gen_frames_dir) + return fid + +if __name__ == "__main__": + gt_video_path = "data\Input\Obama01.mp4" + gen_video_path = "results\Obama1\Obama01\Obama01.avi" + output_dir = "./evaluation_output" + os.makedirs(output_dir, exist_ok=True) + + fid_value = compute_fid_for_videos(gt_video_path, gen_video_path, output_dir, target_fps=30) + print(f"FID: {fid_value}")