diff --git a/HuaWeiExperiment/.idea/.gitignore b/HuaWeiExperiment/.idea/.gitignore new file mode 100644 index 00000000..359bb530 --- /dev/null +++ b/HuaWeiExperiment/.idea/.gitignore @@ -0,0 +1,3 @@ +# Files ignored by default +/shelf/ +/workspace.xml diff --git a/HuaWeiExperiment/.idea/HuaWeiExperiment.iml b/HuaWeiExperiment/.idea/HuaWeiExperiment.iml new file mode 100644 index 00000000..d0876a78 --- /dev/null +++ b/HuaWeiExperiment/.idea/HuaWeiExperiment.iml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/HuaWeiExperiment/.idea/inspectionProfiles/profiles_settings.xml b/HuaWeiExperiment/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 00000000..105ce2da --- /dev/null +++ b/HuaWeiExperiment/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/HuaWeiExperiment/.idea/misc.xml b/HuaWeiExperiment/.idea/misc.xml new file mode 100644 index 00000000..dbb7ab3a --- /dev/null +++ b/HuaWeiExperiment/.idea/misc.xml @@ -0,0 +1,4 @@ + + + + \ No newline at end of file diff --git a/HuaWeiExperiment/.idea/modules.xml b/HuaWeiExperiment/.idea/modules.xml new file mode 100644 index 00000000..76bb2387 --- /dev/null +++ b/HuaWeiExperiment/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/HuaWeiExperiment/ReadMe.md b/HuaWeiExperiment/ReadMe.md new file mode 100644 index 00000000..b8227097 --- /dev/null +++ b/HuaWeiExperiment/ReadMe.md @@ -0,0 +1,9 @@ +For details such as environment setup, data processing, model training, and evaluation, please see the file "语音识别-华为实验-wavenet.docx"; here I only briefly describe my work. + +The project I chose is wavenet. + +Because the dataset is too large, it is not included here. If needed, please download it (2.6G) from https://keithito.com/LJ-Speech-Dataset and preprocess it following the document (after preprocessing, roughly 20+ GB of space is required). + +Due to the hardware limitations of my rather weak computer, I could only train on the CPU. To save time, I also set the number of training epochs to 1 (originally 2000). Even so, training the model still took about four and a half hours, and because the epoch count is so small, the results are poor; please bear with me. + +The experiment results are in \wavenet\saveAudio, where files named like xxx_gen.wav are the generated audio and xxx_ref.wav are the reference audio. \ No newline at end of file diff --git a/HuaWeiExperiment/main.py b/HuaWeiExperiment/main.py new file mode 100644 index 
00000000..90ffb6fd --- /dev/null +++ b/HuaWeiExperiment/main.py @@ -0,0 +1,16 @@ +# This is a sample Python script. + +# Press Shift+F10 to execute it or replace it with your code. +# Double-press Shift to search everywhere for classes, files, tool windows, actions, and settings. + + +def print_hi(name): + # Use a breakpoint in the code line below to debug your script. + print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint. + + +# Press the green button in the gutter to run the script. +if __name__ == '__main__': + print_hi('PyCharm') + +# Visit https://www.jetbrains.com/help/pycharm/ for PyCharm help diff --git a/HuaWeiExperiment/train.log b/HuaWeiExperiment/train.log new file mode 100644 index 00000000..91ca6b61 --- /dev/null +++ b/HuaWeiExperiment/train.log @@ -0,0 +1 @@ +E:\anaconda3\envs\HuaWeiExperiment\python.exe: can't open file 'E:\python\pythonProjects\HuaWeiExperiment\train.py': [Errno 2] No such file or directory diff --git a/HuaWeiExperiment/wavenet/README.md b/HuaWeiExperiment/wavenet/README.md new file mode 100644 index 00000000..ddb963ba --- /dev/null +++ b/HuaWeiExperiment/wavenet/README.md @@ -0,0 +1,305 @@ +# Contents + +- [WaveNet Description](#WaveNet-description) +- [Model Architecture](#model-architecture) +- [Dataset](#dataset) +- [Environment Requirements](#environment-requirements) +- [Script Description](#script-description) + - [Script and Sample Code](#script-and-sample-code) + - [Script Parameters](#script-parameters) + - [Training Process](#training-process) + - [Evaluation Process](#evaluation-process) + - [Convert Process](#convert-process) +- [Model Description](#model-description) + - [Performance](#performance) + - [Training Performance](#training-performance) + - [Inference Performance](#inference-performance) +- [ModelZoo Homepage](#modelzoo-homepage) + +# [WaveNet Description](#contents) + +WaveNet is a deep neural network for generating raw audio waveforms. The model is fully probabilistic and autoregressive, with the predictive distribution for each audio sample conditioned on all previous ones. We support training and evaluation on Ascend, GPU, and CPU. 
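The autoregressive factorization described above (each sample conditioned on all previous ones) can be sketched as a simple sampling loop. This is an illustrative sketch only: `predict_next` is a hypothetical stand-in for the WaveNet forward pass and is not part of this repository.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_autoregressive(predict_next, n_samples, receptive_field=1024):
    # Each new sample is drawn from a distribution conditioned on the
    # previously generated samples (truncated to the receptive field).
    generated = []
    for _ in range(n_samples):
        context = np.asarray(generated[-receptive_field:], dtype=np.int64)
        probs = predict_next(context)  # distribution over 256 quantized classes
        generated.append(rng.choice(len(probs), p=probs))
    return np.asarray(generated)

# Dummy "network": ignores its context and returns a uniform distribution.
uniform = lambda context: np.full(256, 1.0 / 256)
samples = sample_autoregressive(uniform, n_samples=100)
```

Because every step depends on the previous output, this loop cannot be parallelized across time, which is why WaveNet inference is slow compared to training.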
+ +[Paper](https://arxiv.org/pdf/1609.03499.pdf): Oord A, Dieleman S, Zen H, et al. Wavenet: A generative model for raw audio. + +# [Model Architecture](#contents) + +The current model consists of a pre-convolution layer, followed by several residual blocks, each with residual and skip connections and gated activation units. +Finally, post-convolution layers are added to predict the distribution. + +# [Dataset](#contents) + +In the following sections, we will introduce how to run the scripts using the related dataset below. + +Dataset used: [The LJ Speech Dataset]() + +- Dataset size: 2.6G +- Data format: audio clips (13100) and transcriptions + +- The dataset structure is as follows: + + ```path + . + └── LJSpeech-1.1 + ├─ wavs //audio clip files + └─ metadata.csv //transcripts + ``` + +# [Environment Requirements](#contents) + +- Hardware (Ascend/GPU/CPU) + - Prepare the hardware environment with an Ascend/GPU/CPU processor. +- Framework + - [MindSpore](https://www.mindspore.cn/install/en) +- For more information, please check the resources below: + - [MindSpore tutorials](https://www.mindspore.cn/tutorials/en/master/index.html) + - [MindSpore Python API](https://www.mindspore.cn/docs/en/master/api_python/mindspore.html) + +# [Script Description](#contents) + +## [Script and Sample Code](#contents) + +**Note that some of the scripts described below are not included in our code**. These scripts should first be downloaded from [r9y9](https://github.com/r9y9/wavenet_vocoder) and added to this project. + +```path +. 
+├── audio + └──wavenet + ├── scripts + │ ├──run_distribute_train_ascend.sh // launch distributed training with Ascend platform + │ ├──run_distribute_train_gpu.sh // launch distributed training with GPU platform + │ ├──run_eval_ascend.sh // launch evaluation with Ascend platform + │ ├──run_eval_gpu.sh // launch evaluation with GPU platform + │ ├──run_eval_cpu.sh // launch evaluation with CPU platform + │ ├──run_standalone_train_ascend.sh // launch standalone training with Ascend platform + │ ├──run_standalone_train_gpu.sh // launch standalone training with GPU platform + │ └──run_standalone_train_cpu.sh // launch standalone training with CPU platform + ├── datasets // Process audio files for generating train/evaluate data + ├── egs // Note the egs folder should be downloaded from the above link + ├── utils // Note the utils folder should be downloaded from the above link + ├── audio.py // Audio utils. Note this script should be downloaded from the above link + ├── compute-meanvar-stats.py // Compute mean-variance Normalization stats. Note this script should be downloaded from the above link + ├── evaluate.py // Evaluation + ├── export.py // Convert mindspore model to air/mindir model + ├── hparams.py // Hyper-parameter configuration. Note this script should be downloaded from the above link + ├── mksubset.py // Make subset of dataset. Note this script should be downloaded from the above link + ├── preprocess.py // Preprocess dataset. Note this script should be downloaded from the above link + ├── preprocess_normalize.py // Perform meanvar Normalization to preprocessed features. Note this script should be downloaded from the above link + ├── README.md // Descriptions about WaveNet + ├── train.py // Training scripts + ├── train_pytorch.py // Note this script should be downloaded from the above link. 
The initial name of this script is train.py in the project from the link + ├── src + │ ├──__init__.py + │ ├──dataset.py // Generate dataloader and data processing entry + │ ├──callback.py // Callbacks to monitor the training + │ ├──lr_generator.py // Learning rate generator + │ └──loss.py // Loss function definition + └── wavenet_vocoder + ├──__init__.py + ├──conv.py // Extended 1D convolution + ├──mixture.py // Loss function for training and sample function for testing + ├──modules.py // Modules for Wavenet construction + ├──upsample.py // Upsample layer definition + ├──util.py // Utils. Note this script should be downloaded from the above link + ├──wavenet.py // WaveNet networks + └──tfcompat // Note this script should be downloaded from the above link + ├──__init__.py + └──hparam.py // Param management tools +``` + +## [Script Parameters](#contents) + +### Training + +```text +usage: train.py [--data_path DATA_PATH] [--preset PRESET] + [--checkpoint_dir CHECKPOINT_DIR] [--checkpoint CHECKPOINT] + [--speaker_id SPEAKER_ID] [--platform PLATFORM] + [--mode_name MODE] [--is_distributed IS_DISTRIBUTED] + +options: + --data_path dataset path + --preset path of preset parameters (json) + --checkpoint_dir directory for saving model checkpoints + --checkpoint pre-trained ckpt path, default is "./checkpoints" + --speaker_id specific speaker of data in case for multi-speaker datasets, not used currently + --platform specify platform to be used, default is "GPU" + --mode_name specify graph mode, default is "GRAPH" + --is_distributed whether distributed training or not + +``` + +### Evaluation + +```text +usage: evaluate.py [--data_path DATA_PATH] [--preset PRESET] + [--pretrain_ckpt PRETRAIN_CKPT] [--is_numpy] + [--output_path OUTPUT_PATH] [--speaker_id SPEAKER_ID] + [--platform PLATFORM] +options: + --data_path dataset path + --preset path of preset parameters (json) + --pretrain_ckpt pre-trained ckpt path + --is_numpy whether using numpy for inference or not + 
--output_path path to save synthesized audio + --speaker_id specific speaker of data in case for multi-speaker datasets, not used currently + --platform specify platform to be used, default is "GPU" +``` + +More parameters for training and evaluation can be set in file `hparams.py`. + +## [Training Process](#contents) + +Before your first training, some dependency scripts should be downloaded and placed in the correct directories as described in [Script and Sample Code]. +After that, raw data should be pre-processed using the scripts in `egs`. The directory of egs is as follows: + +```path +. +├── egs + ├──gaussian + │ ├──conf + │ │ ├──gaussian_wavenet.json + │ │ └──gaussian_wavenet_demo.json + │ └──run.sh + ├──mol + │ ├──conf + │ │ ├──mol_wavenet.json + │ │ └──mol_wavenet_demo.json + │ └──run.sh + ├──mulaw256 + │ ├──conf + │ │ ├──mulaw_wavenet.json + │ │ └──mulaw_wavenet_demo.json + │ └──run.sh + └──README.md +``` + +In this project, three different losses are implemented to train the network: + +- mulaw256: categorical output distribution. The input is 8-bit mulaw quantized waveform. +- mol: discretized mix logistic loss. The input is 16-bit raw audio. +- gaussian: mix gaussian loss. The input is 16-bit raw audio. + +The three folders gaussian, mol, and mulaw256 are used to generate the corresponding training data. For example, to generate the training data for +the mix gaussian loss, first modify line 28 of `run.sh`: change `conf/gaussian_wavenet_demo.json` to +`conf/gaussian_wavenet.json`. We use the default parameters in `gaussian_wavenet.json`. With this setting, the generated data suits the mix gaussian loss, and +some parameters in `hparams.py` will be overridden by those in `gaussian_wavenet.json`. You can also define your own hyper-parameter json here. After the modification, +the following commands can be run for data generation. 
Note that if you want to change the values of some parameters, modify them in `gaussian_wavenet.json` rather than `hparams.py`, since values in `gaussian_wavenet.json` override those in `hparams.py`. + +```bash +bash run.sh --stage 0 --stop-stage 0 --db-root /path_to_dataset/LJSpeech-1.1/wavs +bash run.sh --stage 1 --stop-stage 1 +``` + +After the processing, the directory of gaussian will be as follows: + +```path +. +├── gaussian + ├──conf + ├──data + ├──exp + └──dump + └──lj + └──logmelspectrogram + ├──org + └──norm + ├──train_no_dev + ├──dev + └──eval +``` + +The train_no_dev folder contains the final training data. For mol and mulaw256, the process is the same. When the training data is prepared, +you can run the following commands to train the network: + +```bash +Standalone training +Ascend: +bash ./scripts/run_standalone_train_ascend.sh train.py [/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/] [/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json] [path_to_save_ckpt] [DEVICE_ID] + +GPU: +bash ./scripts/run_standalone_train_gpu.sh [DEVICE_ID] [/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/] [/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json] [path_to_save_ckpt] + +CPU: +bash ./scripts/run_standalone_train_cpu.sh [/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/] [/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json] [path_to_save_ckpt] + +Distributed training +Ascend: +bash ./scripts/run_distribute_train_ascend.sh [/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/] [/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json] [path_to_save_ckpt] [path_to_hccl_config_file] [RANK_SIZE] [FIRST_DEVICE_ID] + +GPU(8p): +bash ./scripts/run_distribute_train_gpu.sh [/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/] [/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json] [path_to_save_ckpt] +``` + +## [Evaluation Process](#contents) + +WaveNet has a process of auto-regression and this process currently cannot be run in Graph 
mode (i.e., the auto-regression cannot be placed inside `construct`). Therefore, we implement the process in an ordinary Python function. Here, we provide two ways to realize the function: using NumPy or using MindSpore ops. One can set `is_numpy` to determine which mode is used. We recommend using NumPy on GPU since it is much faster than using MindSpore ops. This is because the auto-regression process only calls some simple operations like Matmul and Bias_add; unlike Graph mode, each step incurs some fixed overhead, which leads to a lower speed. For more information, please refer to +this [link](https://bbs.huaweicloud.com/forum/thread-94852-1-1.html). + +```bash +Evaluation +Ascend (using mindspore): +bash ./scripts/run_eval_ascend.sh [DEVICE_ID] evaluate.py [/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval] [/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json] [path_to_load_ckpt] [path_to_save_audio] + +Ascend(using numpy): +bash ./scripts/run_eval_ascend.sh [DEVICE_ID] evaluate.py [/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval] [/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json] [path_to_load_ckpt] is_numpy [path_to_save_audio] + +GPU (using mindspore): +bash ./scripts/run_eval_gpu.sh [CUDA_DEVICE_ID] [/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval] [/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json] [path_to_load_ckpt] [path_to_save_audio] + +GPU (using numpy): +bash ./scripts/run_eval_gpu.sh [CUDA_DEVICE_ID] [/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval] [/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json] [path_to_load_ckpt] is_numpy [path_to_save_audio] + +CPU: +bash ./scripts/run_eval_cpu.sh [/path_to_egs/egs/gaussian/dump/lj/logmelspectrogram/norm/eval] [/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json] [path_to_load_ckpt] [is_numpy] [path_to_save_audio] +``` + +## [Convert Process](#contents) + +```bash +Ascend: +python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json 
--checkpoint_dir=path_to_dump_hparams --pretrain_ckpt=path_to_load_ckpt --platform=Ascend + +GPU: +python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_dump_hparams --pretrain_ckpt=path_to_load_ckpt + +CPU: +python export.py --preset=/path_to_egs/egs/gaussian/conf/gaussian_wavenet.json --checkpoint_dir=path_to_dump_hparams --pretrain_ckpt=path_to_load_ckpt --platform=CPU +``` + +# [Model Description](#contents) + +## [Performance](#contents) + +### Training Performance + +| Parameters | GPU | Ascend | +| -------------------- | ------------------------------------------------------------ | :----------------------------------------------------------- | +| Resource | NV SMX2 V100-32G | Ascend 910 | +| uploaded Date | 01/14/2021 (month/day/year) | 09/27/2021 (month/day/year) | +| MindSpore Version | 1.0.0 | 1.3.0 | +| Dataset | LJSpeech-1.1 | LJSpeech-1.1 | +| Training Parameters | 1p, epoch=600(max), steps=1635 * epoch, batch_size=8, lr=1e-3 | 8p, epoch=2000(max), steps=1635*epoch, batch_size=24, lr=2e-3 | +| Optimizer | Adam | Adam | +| Loss Function | SoftmaxCrossEntropyWithLogits/discretized_mix_logistic/mix_gaussian | SoftmaxCrossEntropyWithLogits/mix_gaussian | +| Loss | around 2.0(mulaw256)/around 4.5(mol)/around -6.0(gaussian) | around -5.0(gaussian) | +| Speed | 1p 1.467s/step | | +| Total time: training | 1p(mol/gaussian): around 4 days; 2p(mulaw256): around 1 week | 8p: around 1 week | +| Checkpoint | 59.79M/54.87M/54.83M (.ckpt file) | 42.73M (.ckpt file) | +| Scripts | [WaveNet script](https://gitee.com/mindspore/models/tree/master/research/audio/wavenet) | [WaveNet script](https://gitee.com/mindspore/models/tree/master/research/audio/wavenet) | + +### Inference Performance + +Audio samples will be demonstrated online soon. + +# [ModelZoo Homepage](#contents) + +Please check the official [homepage](https://gitee.com/mindspore/models). 
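The mulaw256 configuration described under [Training Process] trains on 8-bit mu-law quantized waveforms. As a hedged sketch of that input format (NumPy only, independent of this repository's `nnmnkwii`-based preprocessing), mu-law companding and its inverse look like this:

```python
import numpy as np

def mulaw_encode(x, mu=255):
    # Compress a [-1, 1] waveform with the mu-law curve, then quantize
    # it to mu + 1 = 256 integer classes (the "mulaw256" input format).
    y = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
    return ((y + 1) / 2 * mu + 0.5).astype(np.int64)

def mulaw_decode(q, mu=255):
    # Map the class index back to [-1, 1] and invert the companding curve.
    y = 2 * (q.astype(np.float64) / mu) - 1
    return np.sign(y) * ((1 + mu) ** np.abs(y) - 1) / mu

x = np.linspace(-1.0, 1.0, 101)
roundtrip = mulaw_decode(mulaw_encode(x))
```

Mu-law allocates finer resolution near zero, where speech energy concentrates, so the round-trip quantization error is small for quiet samples and grows toward ±1.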
+ +## FAQ + +Please refer to the [ModelZoo FAQ](https://gitee.com/mindspore/models#FAQ) for answers to common questions. + +- **Q: What third-party packages are required and how to install them?** + + **A**: nnmnkwii, librosa (preferably 0.4.0), and tqdm are needed in order to run wavenet. Instructions on how to install these packages can be found at [nnmnkwii](https://github.com/r9y9/nnmnkwii). On Windows specifically, installing nnmnkwii requires pysptk, which in turn requires Microsoft Visual C++ 14.0 to be installed first. diff --git a/HuaWeiExperiment/wavenet/WaveNet.mindir b/HuaWeiExperiment/wavenet/WaveNet.mindir new file mode 100644 index 00000000..58adb79e Binary files /dev/null and b/HuaWeiExperiment/wavenet/WaveNet.mindir differ diff --git a/HuaWeiExperiment/wavenet/audio.py b/HuaWeiExperiment/wavenet/audio.py new file mode 100644 index 00000000..663516f2 --- /dev/null +++ b/HuaWeiExperiment/wavenet/audio.py @@ -0,0 +1,173 @@ +import librosa +import librosa.filters +import numpy as np +from hparams import hparams +from scipy.io import wavfile +from nnmnkwii import preprocessing as P + + +def low_cut_filter(x, fs, cutoff=70): + """APPLY LOW CUT FILTER. + + https://github.com/kan-bayashi/PytorchWaveNetVocoder + + Args: + x (ndarray): Waveform sequence. + fs (int): Sampling frequency. + cutoff (float): Cutoff frequency of low cut filter. + Return: + ndarray: Low cut filtered waveform sequence. 
+ """ + nyquist = fs // 2 + norm_cutoff = cutoff / nyquist + from scipy.signal import firwin, lfilter + + # low cut filter + fil = firwin(255, norm_cutoff, pass_zero=False) + lcf_x = lfilter(fil, 1, x) + + return lcf_x + + +def load_wav(path): + sr, x = wavfile.read(path) + signed_int16_max = 2**15 + if x.dtype == np.int16: + x = x.astype(np.float32) / signed_int16_max + if sr != hparams.sample_rate: + x = librosa.resample(x, sr, hparams.sample_rate) + x = np.clip(x, -1.0, 1.0) + return x + + +def save_wav(wav, path): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + wavfile.write(path, hparams.sample_rate, wav.astype(np.int16)) + + +def trim(quantized): + start, end = start_and_end_indices(quantized, hparams.silence_threshold) + return quantized[start:end] + + +def preemphasis(x, coef=0.85): + return P.preemphasis(x, coef) + + +def inv_preemphasis(x, coef=0.85): + return P.inv_preemphasis(x, coef) + + +def adjust_time_resolution(quantized, mel): + """Adjust time resolution by repeating features + + Args: + quantized (ndarray): (T,) + mel (ndarray): (N, D) + + Returns: + tuple: Tuple of (T,) and (T, D) + """ + assert len(quantized.shape) == 1 + assert len(mel.shape) == 2 + + upsample_factor = quantized.size // mel.shape[0] + mel = np.repeat(mel, upsample_factor, axis=0) + n_pad = quantized.size - mel.shape[0] + if n_pad != 0: + assert n_pad > 0 + mel = np.pad(mel, [(0, n_pad), (0, 0)], mode="constant", constant_values=0) + + # trim + start, end = start_and_end_indices(quantized, hparams.silence_threshold) + + return quantized[start:end], mel[start:end, :] + + +def start_and_end_indices(quantized, silence_threshold=2): + for start in range(quantized.size): + if abs(quantized[start] - 127) > silence_threshold: + break + for end in range(quantized.size - 1, 1, -1): + if abs(quantized[end] - 127) > silence_threshold: + break + + assert abs(quantized[start] - 127) > silence_threshold + assert abs(quantized[end] - 127) > silence_threshold + + return start, end + + +def 
logmelspectrogram(y, pad_mode="reflect"): + """Same log-melspectrogram computation as espnet + https://github.com/espnet/espnet + from espnet.transform.spectrogram import logmelspectrogram + """ + D = _stft(y, pad_mode=pad_mode) + S = _linear_to_mel(np.abs(D)) + S = np.log10(np.maximum(S, 1e-10)) + return S + + +def get_hop_size(): + hop_size = hparams.hop_size + if hop_size is None: + assert hparams.frame_shift_ms is not None + hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) + return hop_size + + +def get_win_length(): + win_length = hparams.win_length + if win_length < 0: + assert hparams.win_length_ms > 0 + win_length = int(hparams.win_length_ms / 1000 * hparams.sample_rate) + return win_length + + +def _stft(y, pad_mode="constant"): + # use constant padding (defaults to zeros) instead of reflection padding + return librosa.stft(y=y, n_fft=hparams.fft_size, hop_length=get_hop_size(), + win_length=get_win_length(), window=hparams.window, + pad_mode=pad_mode) + + +def pad_lr(x, fsize, fshift): + return (0, fsize) + +# Conversions: + + +_mel_basis = None + + +def _linear_to_mel(spectrogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return np.dot(_mel_basis, spectrogram) + + +def _build_mel_basis(): + if hparams.fmax is not None: + assert hparams.fmax <= hparams.sample_rate // 2 + return librosa.filters.mel(hparams.sample_rate, hparams.fft_size, + fmin=hparams.fmin, fmax=hparams.fmax, + n_mels=hparams.num_mels) + + +def _amp_to_db(x): + min_level = np.exp(hparams.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + + +def _db_to_amp(x): + return np.power(10.0, x * 0.05) + + +def _normalize(S): + return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1) + + +def _denormalize(S): + return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db diff --git a/HuaWeiExperiment/wavenet/compute-meanvar-stats.py 
b/HuaWeiExperiment/wavenet/compute-meanvar-stats.py new file mode 100644 index 00000000..50527758 --- /dev/null +++ b/HuaWeiExperiment/wavenet/compute-meanvar-stats.py @@ -0,0 +1,38 @@ +# coding: utf-8 +"""Compute mean-variance normalization stats. + +usage: compute_meanvar_stats.py [options] <list_file> <out_path> + +options: + -h, --help Show help message. + --verbose=<n> Verbosity [default: 0]. +""" +from docopt import docopt +import sys +from tqdm import tqdm +import numpy as np +import json + +from sklearn.preprocessing import StandardScaler +import joblib + +if __name__ == "__main__": + args = docopt(__doc__) + list_file = args["<list_file>"] + out_path = args["<out_path>"] + verbose = int(args["--verbose"]) + + scaler = StandardScaler() + with open(list_file) as f: + lines = f.readlines() + assert len(lines) > 0 + for path in tqdm(lines): + c = np.load(path.strip()) + scaler.partial_fit(c) + joblib.dump(scaler, out_path) + + if verbose > 0: + print("mean:\n{}".format(scaler.mean_)) + print("var:\n{}".format(scaler.var_)) + + sys.exit(0) diff --git a/HuaWeiExperiment/wavenet/datasets/audio.py b/HuaWeiExperiment/wavenet/datasets/audio.py new file mode 100644 index 00000000..663516f2 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/audio.py @@ -0,0 +1,173 @@ +import librosa +import librosa.filters +import numpy as np +from hparams import hparams +from scipy.io import wavfile +from nnmnkwii import preprocessing as P + + +def low_cut_filter(x, fs, cutoff=70): + """APPLY LOW CUT FILTER. + + https://github.com/kan-bayashi/PytorchWaveNetVocoder + + Args: + x (ndarray): Waveform sequence. + fs (int): Sampling frequency. + cutoff (float): Cutoff frequency of low cut filter. + Return: + ndarray: Low cut filtered waveform sequence. 
+ """ + nyquist = fs // 2 + norm_cutoff = cutoff / nyquist + from scipy.signal import firwin, lfilter + + # low cut filter + fil = firwin(255, norm_cutoff, pass_zero=False) + lcf_x = lfilter(fil, 1, x) + + return lcf_x + + +def load_wav(path): + sr, x = wavfile.read(path) + signed_int16_max = 2**15 + if x.dtype == np.int16: + x = x.astype(np.float32) / signed_int16_max + if sr != hparams.sample_rate: + x = librosa.resample(x, sr, hparams.sample_rate) + x = np.clip(x, -1.0, 1.0) + return x + + +def save_wav(wav, path): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + wavfile.write(path, hparams.sample_rate, wav.astype(np.int16)) + + +def trim(quantized): + start, end = start_and_end_indices(quantized, hparams.silence_threshold) + return quantized[start:end] + + +def preemphasis(x, coef=0.85): + return P.preemphasis(x, coef) + + +def inv_preemphasis(x, coef=0.85): + return P.inv_preemphasis(x, coef) + + +def adjust_time_resolution(quantized, mel): + """Adjust time resolution by repeating features + + Args: + quantized (ndarray): (T,) + mel (ndarray): (N, D) + + Returns: + tuple: Tuple of (T,) and (T, D) + """ + assert len(quantized.shape) == 1 + assert len(mel.shape) == 2 + + upsample_factor = quantized.size // mel.shape[0] + mel = np.repeat(mel, upsample_factor, axis=0) + n_pad = quantized.size - mel.shape[0] + if n_pad != 0: + assert n_pad > 0 + mel = np.pad(mel, [(0, n_pad), (0, 0)], mode="constant", constant_values=0) + + # trim + start, end = start_and_end_indices(quantized, hparams.silence_threshold) + + return quantized[start:end], mel[start:end, :] + + +def start_and_end_indices(quantized, silence_threshold=2): + for start in range(quantized.size): + if abs(quantized[start] - 127) > silence_threshold: + break + for end in range(quantized.size - 1, 1, -1): + if abs(quantized[end] - 127) > silence_threshold: + break + + assert abs(quantized[start] - 127) > silence_threshold + assert abs(quantized[end] - 127) > silence_threshold + + return start, end + + +def 
logmelspectrogram(y, pad_mode="reflect"): + """Same log-melspectrogram computation as espnet + https://github.com/espnet/espnet + from espnet.transform.spectrogram import logmelspectrogram + """ + D = _stft(y, pad_mode=pad_mode) + S = _linear_to_mel(np.abs(D)) + S = np.log10(np.maximum(S, 1e-10)) + return S + + +def get_hop_size(): + hop_size = hparams.hop_size + if hop_size is None: + assert hparams.frame_shift_ms is not None + hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) + return hop_size + + +def get_win_length(): + win_length = hparams.win_length + if win_length < 0: + assert hparams.win_length_ms > 0 + win_length = int(hparams.win_length_ms / 1000 * hparams.sample_rate) + return win_length + + +def _stft(y, pad_mode="constant"): + # use constant padding (defaults to zeros) instead of reflection padding + return librosa.stft(y=y, n_fft=hparams.fft_size, hop_length=get_hop_size(), + win_length=get_win_length(), window=hparams.window, + pad_mode=pad_mode) + + +def pad_lr(x, fsize, fshift): + return (0, fsize) + +# Conversions: + + +_mel_basis = None + + +def _linear_to_mel(spectrogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return np.dot(_mel_basis, spectrogram) + + +def _build_mel_basis(): + if hparams.fmax is not None: + assert hparams.fmax <= hparams.sample_rate // 2 + return librosa.filters.mel(hparams.sample_rate, hparams.fft_size, + fmin=hparams.fmin, fmax=hparams.fmax, + n_mels=hparams.num_mels) + + +def _amp_to_db(x): + min_level = np.exp(hparams.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + + +def _db_to_amp(x): + return np.power(10.0, x * 0.05) + + +def _normalize(S): + return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1) + + +def _denormalize(S): + return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db diff --git a/HuaWeiExperiment/wavenet/datasets/train_pytorch.py 
b/HuaWeiExperiment/wavenet/datasets/train_pytorch.py new file mode 100644 index 00000000..8f8e56e4 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/train_pytorch.py @@ -0,0 +1,1117 @@ +"""Training script for WaveNet vocoder + +usage: train.py [options] + +options: + --dump-root=<dir> Directory containing preprocessed features. + --checkpoint-dir=<dir> Directory where to save model checkpoints [default: checkpoints]. + --hparams=<params> Hyper parameters [default: ]. + --preset=<json> Path of preset parameters (json). + --checkpoint=<path> Restore model from checkpoint path if given. + --restore-parts=<path> Restore part of the model. + --log-event-path=<name> Log event path. + --reset-optimizer Reset optimizer. + --speaker-id=<N> Use specific speaker of data in case for multi-speaker datasets. + -h, --help Show this help message and exit +""" +from docopt import docopt + +import sys + +import os +from os.path import dirname, join, expanduser, exists +from tqdm import tqdm +from datetime import datetime +import random +import json +from glob import glob + +import numpy as np + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + + +import torch +from torch import nn +from torch.nn import functional as F +from torch import optim +import torch.backends.cudnn as cudnn +from torch.utils import data as data_utils +from torch.utils.data.sampler import Sampler + +import torch.optim.lr_scheduler as lrschedule + + +from nnmnkwii import preprocessing as P +from nnmnkwii.datasets import FileSourceDataset, FileDataSource + +import librosa.display + +from tensorboardX import SummaryWriter +from matplotlib import cm +from warnings import warn + +from wavenet_vocoder import WaveNet +from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_raw, is_scalar_input +from wavenet_vocoder.mixture import discretized_mix_logistic_loss +from wavenet_vocoder.mixture import sample_from_discretized_mix_logistic +from wavenet_vocoder.mixture import mix_gaussian_loss +from wavenet_vocoder.mixture import 
sample_from_mix_gaussian + +import audio +from hparams import hparams, hparams_debug_string + + +global_step = 0 +global_test_step = 0 +global_epoch = 0 +use_cuda = torch.cuda.is_available() +if use_cuda: + cudnn.benchmark = True + + +def sanity_check(model, c, g): + if model.has_speaker_embedding(): + if g is None: + raise RuntimeError( + "WaveNet expects speaker embedding, but speaker-id is not provided") + else: + if g is not None: + raise RuntimeError( + "WaveNet expects no speaker embedding, but speaker-id is provided") + + if model.local_conditioning_enabled(): + if c is None: + raise RuntimeError("WaveNet expects conditional features, but not given") + else: + if c is not None: + raise RuntimeError("WaveNet expects no conditional features, but given") + + +def maybe_set_epochs_based_on_max_steps(hp, steps_per_epoch): + nepochs = hp.nepochs + max_train_steps = hp.max_train_steps + if max_train_steps is not None: + epochs = int(np.ceil(max_train_steps / steps_per_epoch)) + hp.nepochs = epochs + print("info; Number of epochs is set based on max_train_steps: {}".format(epochs)) + + +def _pad(seq, max_len, constant_values=0): + return np.pad(seq, (0, max_len - len(seq)), + mode='constant', constant_values=constant_values) + + +def _pad_2d(x, max_len, b_pad=0, constant_values=0): + x = np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)], + mode="constant", constant_values=constant_values) + return x + +# from: https://github.com/keras-team/keras/blob/master/keras/utils/np_utils.py +# to avoid keras dependency + + +def to_categorical(y, num_classes=None, dtype='float32'): + """Converts a class vector (integers) to binary class matrix. + E.g. for use with categorical_crossentropy. + # Arguments + y: class vector to be converted into a matrix + (integers from 0 to num_classes). + num_classes: total number of classes. + dtype: The data type expected by the input, as a string + (`float32`, `float64`, `int32`...) 
+ # Returns + A binary matrix representation of the input. The classes axis + is placed last. + # Example + ```python + # Consider an array of 5 labels out of a set of 3 classes {0, 1, 2}: + > labels + array([0, 2, 1, 2, 0]) + # `to_categorical` converts this into a matrix with as many + # columns as there are classes. The number of rows + # stays the same. + > to_categorical(labels) + array([[ 1., 0., 0.], + [ 0., 0., 1.], + [ 0., 1., 0.], + [ 0., 0., 1.], + [ 1., 0., 0.]], dtype=float32) + ``` + """ + + y = np.array(y, dtype='int') + input_shape = y.shape + if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: + input_shape = tuple(input_shape[:-1]) + y = y.ravel() + if not num_classes: + num_classes = np.max(y) + 1 + n = y.shape[0] + categorical = np.zeros((n, num_classes), dtype=dtype) + categorical[np.arange(n), y] = 1 + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) + return categorical + + +# TODO: I know this is too ugly... 
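The vendored `to_categorical` above avoids a Keras dependency; for the 1-D vectors of mu-law class indices this trainer feeds it, the conversion is equivalent to fancy-indexing an identity matrix. A minimal sketch (the `one_hot` name is hypothetical, used only for illustration):

```python
import numpy as np

def one_hot(y, num_classes, dtype="float32"):
    # Row i of the identity matrix is the one-hot vector for class i,
    # so indexing by the label vector yields an (N, num_classes) matrix.
    return np.eye(num_classes, dtype=dtype)[np.asarray(y, dtype=int)]

labels = np.array([0, 2, 1, 2, 0])
encoded = one_hot(labels, 3)
```

For 1-D integer input, this matches `to_categorical(labels, 3)` row for row.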
+class _NPYDataSource(FileDataSource): + def __init__(self, dump_root, col, typ="", speaker_id=None, max_steps=8000, + cin_pad=0, hop_size=256): + self.dump_root = dump_root + self.col = col + self.lengths = [] + self.speaker_id = speaker_id + self.multi_speaker = False + self.speaker_ids = None + self.max_steps = max_steps + self.cin_pad = cin_pad + self.hop_size = hop_size + self.typ = typ + + def collect_files(self): + meta = join(self.dump_root, "train.txt") + if not exists(meta): + paths = sorted(glob(join(self.dump_root, "*-{}.npy".format(self.typ)))) + return paths + + with open(meta, "rb") as f: + lines = f.readlines() + l = lines[0].decode("utf-8").split("|") + assert len(l) == 4 or len(l) == 5 + self.multi_speaker = len(l) == 5 + self.lengths = list( + map(lambda l: int(l.decode("utf-8").split("|")[2]), lines)) + + paths_relative = list(map(lambda l: l.decode("utf-8").split("|")[self.col], lines)) + paths = list(map(lambda f: join(self.dump_root, f), paths_relative)) + + # Exclude small files (assuming lengths are in frame units) + # TODO: consider this for multi-speaker + if self.max_steps is not None: + idx = np.array(self.lengths) * self.hop_size > self.max_steps + 2 * self.cin_pad * self.hop_size + if idx.sum() != len(self.lengths): + print("{} short samples are omitted for training.".format(len(self.lengths) - idx.sum())) + self.lengths = list(np.array(self.lengths)[idx]) + paths = list(np.array(paths)[idx]) + + if self.multi_speaker: + speaker_ids = list(map(lambda l: int(l.decode("utf-8").split("|")[-1]), lines)) + self.speaker_ids = speaker_ids + if self.speaker_id is not None: + # Filter by speaker_id + # using multi-speaker dataset as a single speaker dataset + indices = np.array(speaker_ids) == self.speaker_id + paths = list(np.array(paths)[indices]) + self.lengths = list(np.array(self.lengths)[indices]) + # need to cast numpy.int64 to int + self.lengths = list(map(int, self.lengths)) + self.multi_speaker = False + + if self.multi_speaker:
+ speaker_ids_np = list(np.array(self.speaker_ids)[indices]) + self.speaker_ids = list(map(int, speaker_ids_np)) + assert len(paths) == len(self.speaker_ids) + + return paths + + def collect_features(self, path): + return np.load(path) + + +class RawAudioDataSource(_NPYDataSource): + def __init__(self, dump_root, **kwargs): + super(RawAudioDataSource, self).__init__(dump_root, 0, "wave", **kwargs) + + +class MelSpecDataSource(_NPYDataSource): + def __init__(self, dump_root, **kwargs): + super(MelSpecDataSource, self).__init__(dump_root, 1, "feats", **kwargs) + + +class PartialyRandomizedSimilarTimeLengthSampler(Sampler): + """Partially randomized sampler + + 1. Sort by lengths + 2. Pick a small patch and randomize it + 3. Permutate mini-batches + """ + + def __init__(self, lengths, batch_size=8, batch_group_size=None): + self.lengths, self.sorted_indices = torch.sort(torch.LongTensor(lengths)) + + self.batch_size = batch_size + if batch_group_size is None: + batch_group_size = min(batch_size * 8, len(self.lengths)) + if batch_group_size % batch_size != 0: + batch_group_size -= batch_group_size % batch_size + + self.batch_group_size = batch_group_size + assert batch_group_size % batch_size == 0 + + def __iter__(self): + indices = self.sorted_indices.numpy() + batch_group_size = self.batch_group_size + s, e = 0, 0 + bins = [] + for i in range(len(indices) // batch_group_size): + s = i * batch_group_size + e = s + batch_group_size + group = indices[s:e] + random.shuffle(group) + bins += [group] + + # Permutate batches + random.shuffle(bins) + binned_idx = np.stack(bins).reshape(-1) + + # Handle last elements + s += batch_group_size + if s < len(indices): + last_bin = indices[len(binned_idx):] + random.shuffle(last_bin) + binned_idx = np.concatenate([binned_idx, last_bin]) + + return iter(torch.tensor(binned_idx).long()) + + def __len__(self): + return len(self.sorted_indices) + + +class PyTorchDataset(object): + def __init__(self, X, Mel): + self.X = X + self.Mel = 
Mel + # alias + self.multi_speaker = X.file_data_source.multi_speaker + + def __getitem__(self, idx): + if self.Mel is None: + mel = None + else: + mel = self.Mel[idx] + + raw_audio = self.X[idx] + if self.multi_speaker: + speaker_id = self.X.file_data_source.speaker_ids[idx] + else: + speaker_id = None + + # (x,c,g) + return raw_audio, mel, speaker_id + + def __len__(self): + return len(self.X) + + +def sequence_mask(sequence_length, max_len=None): + if max_len is None: + max_len = sequence_length.data.max() + batch_size = sequence_length.size(0) + seq_range = torch.arange(0, max_len).long() + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + if sequence_length.is_cuda: + seq_range_expand = seq_range_expand.cuda() + seq_length_expand = sequence_length.unsqueeze(1) \ + .expand_as(seq_range_expand) + return (seq_range_expand < seq_length_expand).float() + + +# https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/4 +# https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage +class ExponentialMovingAverage(object): + def __init__(self, decay): + self.decay = decay + self.shadow = {} + + def register(self, name, val): + self.shadow[name] = val.clone() + + def update(self, name, x): + assert name in self.shadow + update_delta = self.shadow[name] - x + self.shadow[name] -= (1.0 - self.decay) * update_delta + + +def clone_as_averaged_model(device, model, ema): + assert ema is not None + averaged_model = build_model().to(device) + averaged_model.load_state_dict(model.state_dict()) + for name, param in averaged_model.named_parameters(): + if name in ema.shadow: + param.data = ema.shadow[name].clone() + return averaged_model + + +class MaskedCrossEntropyLoss(nn.Module): + def __init__(self): + super(MaskedCrossEntropyLoss, self).__init__() + self.criterion = nn.CrossEntropyLoss(reduction='none') + + def forward(self, input, target, lengths=None, mask=None, max_len=None): + if lengths is None 
and mask is None: + raise RuntimeError("Should provide either lengths or mask") + + # (B, T, 1) + if mask is None: + mask = sequence_mask(lengths, max_len).unsqueeze(-1) + + # (B, T, D) + mask_ = mask.expand_as(target) + losses = self.criterion(input, target) + return ((losses * mask_).sum()) / mask_.sum() + + +class DiscretizedMixturelogisticLoss(nn.Module): + def __init__(self): + super(DiscretizedMixturelogisticLoss, self).__init__() + + def forward(self, input, target, lengths=None, mask=None, max_len=None): + if lengths is None and mask is None: + raise RuntimeError("Should provide either lengths or mask") + + # (B, T, 1) + if mask is None: + mask = sequence_mask(lengths, max_len).unsqueeze(-1) + + # (B, T, 1) + mask_ = mask.expand_as(target) + + losses = discretized_mix_logistic_loss( + input, target, num_classes=hparams.quantize_channels, + log_scale_min=hparams.log_scale_min, reduce=False) + assert losses.size() == target.size() + return ((losses * mask_).sum()) / mask_.sum() + + +class MixtureGaussianLoss(nn.Module): + def __init__(self): + super(MixtureGaussianLoss, self).__init__() + + def forward(self, input, target, lengths=None, mask=None, max_len=None): + if lengths is None and mask is None: + raise RuntimeError("Should provide either lengths or mask") + + # (B, T, 1) + if mask is None: + mask = sequence_mask(lengths, max_len).unsqueeze(-1) + + # (B, T, 1) + mask_ = mask.expand_as(target) + + losses = mix_gaussian_loss( + input, target, log_scale_min=hparams.log_scale_min, reduce=False) + assert losses.size() == target.size() + return ((losses * mask_).sum()) / mask_.sum() + + +def ensure_divisible(length, divisible_by=256, lower=True): + if length % divisible_by == 0: + return length + if lower: + return length - length % divisible_by + else: + return length + (divisible_by - length % divisible_by) + + +def assert_ready_for_upsampling(x, c, cin_pad): + assert len(x) == (len(c) - 2 * cin_pad) * audio.get_hop_size() + + +def collate_fn(batch): + 
"""Create batch + + Args: + batch(tuple): List of tuples + - x[0] (ndarray,int) : list of (T,) + - x[1] (ndarray,int) : list of (T, D) + - x[2] (ndarray,int) : list of (1,), speaker id + Returns: + tuple: Tuple of batch + - x (FloatTensor) : Network inputs (B, C, T) + - y (LongTensor) : Network targets (B, T, 1) + """ + + local_conditioning = len(batch[0]) >= 2 and hparams.cin_channels > 0 + global_conditioning = len(batch[0]) >= 3 and hparams.gin_channels > 0 + + if hparams.max_time_sec is not None: + max_time_steps = int(hparams.max_time_sec * hparams.sample_rate) + elif hparams.max_time_steps is not None: + max_time_steps = hparams.max_time_steps + else: + max_time_steps = None + + # Time resolution adjustment + cin_pad = hparams.cin_pad + if local_conditioning: + new_batch = [] + for idx in range(len(batch)): + x, c, g = batch[idx] + if hparams.upsample_conditional_features: + assert_ready_for_upsampling(x, c, cin_pad=0) + if max_time_steps is not None: + max_steps = ensure_divisible(max_time_steps, audio.get_hop_size(), True) + if len(x) > max_steps: + max_time_frames = max_steps // audio.get_hop_size() + s = np.random.randint(cin_pad, len(c) - max_time_frames - cin_pad) + ts = s * audio.get_hop_size() + x = x[ts:ts + audio.get_hop_size() * max_time_frames] + c = c[s - cin_pad:s + max_time_frames + cin_pad, :] + assert_ready_for_upsampling(x, c, cin_pad=cin_pad) + else: + x, c = audio.adjust_time_resolution(x, c) + if max_time_steps is not None and len(x) > max_time_steps: + s = np.random.randint(cin_pad, len(x) - max_time_steps - cin_pad) + x = x[s:s + max_time_steps] + c = c[s - cin_pad:s + max_time_steps + cin_pad, :] + assert len(x) == len(c) + new_batch.append((x, c, g)) + batch = new_batch + else: + new_batch = [] + for idx in range(len(batch)): + x, c, g = batch[idx] + x = audio.trim(x) + if max_time_steps is not None and len(x) > max_time_steps: + s = np.random.randint(0, len(x) - max_time_steps) + if local_conditioning: + x, c = x[s:s + 
max_time_steps], c[s:s + max_time_steps, :] + else: + x = x[s:s + max_time_steps] + new_batch.append((x, c, g)) + batch = new_batch + + # Lengths + input_lengths = [len(x[0]) for x in batch] + max_input_len = max(input_lengths) + + # (B, T, C) + # pad for time-axis + if is_mulaw_quantize(hparams.input_type): + padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1) + x_batch = np.array([_pad_2d(to_categorical( + x[0], num_classes=hparams.quantize_channels), + max_input_len, 0, padding_value) for x in batch], dtype=np.float32) + else: + x_batch = np.array([_pad_2d(x[0].reshape(-1, 1), max_input_len) + for x in batch], dtype=np.float32) + assert len(x_batch.shape) == 3 + + # (B, T) + if is_mulaw_quantize(hparams.input_type): + padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1) + y_batch = np.array([_pad(x[0], max_input_len, constant_values=padding_value) + for x in batch], dtype=np.int) + else: + y_batch = np.array([_pad(x[0], max_input_len) for x in batch], dtype=np.float32) + assert len(y_batch.shape) == 2 + + # (B, T, D) + if local_conditioning: + max_len = max([len(x[1]) for x in batch]) + c_batch = np.array([_pad_2d(x[1], max_len) for x in batch], dtype=np.float32) + assert len(c_batch.shape) == 3 + # (B x C x T) + c_batch = torch.FloatTensor(c_batch).transpose(1, 2).contiguous() + else: + c_batch = None + + if global_conditioning: + g_batch = torch.LongTensor([x[2] for x in batch]) + else: + g_batch = None + + # Convert to channel-first, i.e., (B, C, T) + x_batch = torch.FloatTensor(x_batch).transpose(1, 2).contiguous() + # Add extra axis + if is_mulaw_quantize(hparams.input_type): + y_batch = torch.LongTensor(y_batch).unsqueeze(-1).contiguous() + else: + y_batch = torch.FloatTensor(y_batch).unsqueeze(-1).contiguous() + + input_lengths = torch.LongTensor(input_lengths) + + return x_batch, y_batch, c_batch, g_batch, input_lengths + + +def time_string(): + return datetime.now().strftime('%Y-%m-%d %H:%M') + + +def
save_waveplot(path, y_hat, y_target): + sr = hparams.sample_rate + + plt.figure(figsize=(16, 6)) + plt.subplot(2, 1, 1) + librosa.display.waveplot(y_target, sr=sr) + plt.subplot(2, 1, 2) + librosa.display.waveplot(y_hat, sr=sr) + plt.tight_layout() + plt.savefig(path, format="png") + plt.close() + + +def eval_model(global_step, writer, device, model, y, c, g, input_lengths, eval_dir, ema=None): + if ema is not None: + print("Using averaged model for evaluation") + model = clone_as_averaged_model(device, model, ema) + model.make_generation_fast_() + + model.eval() + idx = np.random.randint(0, len(y)) + length = input_lengths[idx].data.cpu().item() + + # (T,) + y_target = y[idx].view(-1).data.cpu().numpy()[:length] + + if c is not None: + if hparams.upsample_conditional_features: + c = c[idx, :, :length // audio.get_hop_size() + hparams.cin_pad * 2].unsqueeze(0) + else: + c = c[idx, :, :length].unsqueeze(0) + assert c.dim() == 3 + print("Shape of local conditioning features: {}".format(c.size())) + if g is not None: + # TODO: test + g = g[idx] + print("Shape of global conditioning features: {}".format(g.size())) + + # Dummy silence + if is_mulaw_quantize(hparams.input_type): + initial_value = P.mulaw_quantize(0, hparams.quantize_channels - 1) + elif is_mulaw(hparams.input_type): + initial_value = P.mulaw(0.0, hparams.quantize_channels) + else: + initial_value = 0.0 + + # (C,) + if is_mulaw_quantize(hparams.input_type): + initial_input = to_categorical( + initial_value, num_classes=hparams.quantize_channels).astype(np.float32) + initial_input = torch.from_numpy(initial_input).view( + 1, 1, hparams.quantize_channels) + else: + initial_input = torch.zeros(1, 1, 1).fill_(initial_value) + initial_input = initial_input.to(device) + + # Run the model in fast eval mode + with torch.no_grad(): + y_hat = model.incremental_forward( + initial_input, c=c, g=g, T=length, softmax=True, quantize=True, tqdm=tqdm, + log_scale_min=hparams.log_scale_min) + + if 
is_mulaw_quantize(hparams.input_type): + y_hat = y_hat.max(1)[1].view(-1).long().cpu().data.numpy() + y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1) + y_target = P.inv_mulaw_quantize(y_target, hparams.quantize_channels - 1) + elif is_mulaw(hparams.input_type): + y_hat = P.inv_mulaw(y_hat.view(-1).cpu().data.numpy(), hparams.quantize_channels) + y_target = P.inv_mulaw(y_target, hparams.quantize_channels) + else: + y_hat = y_hat.view(-1).cpu().data.numpy() + + # Save audio + os.makedirs(eval_dir, exist_ok=True) + path = join(eval_dir, "step{:09d}_predicted.wav".format(global_step)) + librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate) + path = join(eval_dir, "step{:09d}_target.wav".format(global_step)) + librosa.output.write_wav(path, y_target, sr=hparams.sample_rate) + + # save figure + path = join(eval_dir, "step{:09d}_waveplots.png".format(global_step)) + save_waveplot(path, y_hat, y_target) + + +def save_states(global_step, writer, y_hat, y, input_lengths, checkpoint_dir=None): + print("Save intermediate states at step {}".format(global_step)) + idx = np.random.randint(0, len(y_hat)) + length = input_lengths[idx].data.cpu().item() + + # (B, C, T) + if y_hat.dim() == 4: + y_hat = y_hat.squeeze(-1) + + if is_mulaw_quantize(hparams.input_type): + # (B, T) + y_hat = F.softmax(y_hat, dim=1).max(1)[1] + + # (T,) + y_hat = y_hat[idx].data.cpu().long().numpy() + y = y[idx].view(-1).data.cpu().long().numpy() + + y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1) + y = P.inv_mulaw_quantize(y, hparams.quantize_channels - 1) + else: + # (B, T) + if hparams.output_distribution == "Logistic": + y_hat = sample_from_discretized_mix_logistic( + y_hat, log_scale_min=hparams.log_scale_min) + elif hparams.output_distribution == "Normal": + y_hat = sample_from_mix_gaussian( + y_hat, log_scale_min=hparams.log_scale_min) + else: + assert False + + # (T,) + y_hat = y_hat[idx].view(-1).data.cpu().numpy() + y = 
y[idx].view(-1).data.cpu().numpy() + + if is_mulaw(hparams.input_type): + y_hat = P.inv_mulaw(y_hat, hparams.quantize_channels) + y = P.inv_mulaw(y, hparams.quantize_channels) + + # Mask by length + y_hat[length:] = 0 + y[length:] = 0 + + # Save audio + audio_dir = join(checkpoint_dir, "intermediate", "audio") + os.makedirs(audio_dir, exist_ok=True) + path = join(audio_dir, "step{:09d}_predicted.wav".format(global_step)) + librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate) + path = join(audio_dir, "step{:09d}_target.wav".format(global_step)) + librosa.output.write_wav(path, y, sr=hparams.sample_rate) + +# workaround for https://github.com/pytorch/pytorch/issues/15716 +# the idea is to return outputs and replicas explicitly, so that making pytorch +# not to release the nodes (this is a pytorch bug though) + + +def data_parallel_workaround(model, input): + device_ids = list(range(torch.cuda.device_count())) + output_device = device_ids[0] + replicas = torch.nn.parallel.replicate(model, device_ids) + inputs = torch.nn.parallel.scatter(input, device_ids) + replicas = replicas[:len(inputs)] + outputs = torch.nn.parallel.parallel_apply(replicas, inputs) + y_hat = torch.nn.parallel.gather(outputs, output_device) + return y_hat, outputs, replicas + + +def __train_step(device, phase, epoch, global_step, global_test_step, + model, optimizer, writer, criterion, + x, y, c, g, input_lengths, + checkpoint_dir, eval_dir=None, do_eval=False, ema=None): + sanity_check(model, c, g) + + # x : (B, C, T) + # y : (B, T, 1) + # c : (B, C, T) + # g : (B,) + train = (phase == "train_no_dev") + clip_thresh = hparams.clip_thresh + if train: + model.train() + step = global_step + else: + model.eval() + step = global_test_step + + # Learning rate schedule + current_lr = hparams.optimizer_params["lr"] + if train and hparams.lr_schedule is not None: + lr_schedule_f = getattr(lrschedule, hparams.lr_schedule) + current_lr = lr_schedule_f( + hparams.optimizer_params["lr"], step, 
**hparams.lr_schedule_kwargs) + for param_group in optimizer.param_groups: + param_group['lr'] = current_lr + optimizer.zero_grad() + + # Prepare data + x, y = x.to(device), y.to(device) + input_lengths = input_lengths.to(device) + c = c.to(device) if c is not None else None + g = g.to(device) if g is not None else None + + # (B, T, 1) + mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1) + mask = mask[:, 1:, :] + + # Apply model + # NOTE: softmax is handled in F.cross_entropy + # y_hat: (B x C x T) + + if use_cuda: + # multi gpu support + # you must make sure that batch size % num gpu == 0 + y_hat, _outputs, _replicas = data_parallel_workaround(model, (x, c, g, False)) + else: + y_hat = model(x, c, g, False) + + if is_mulaw_quantize(hparams.input_type): + # we need 4d inputs for spatial cross entropy loss + # (B, C, T, 1) + y_hat = y_hat.unsqueeze(-1) + loss = criterion(y_hat[:, :, :-1, :], y[:, 1:, :], mask=mask) + else: + loss = criterion(y_hat[:, :, :-1], y[:, 1:, :], mask=mask) + + if train and step > 0 and step % hparams.checkpoint_interval == 0: + save_states(step, writer, y_hat, y, input_lengths, checkpoint_dir) + save_checkpoint(device, model, optimizer, step, checkpoint_dir, epoch, ema) + + if do_eval: + # NOTE: use train step (i.e., global_step) for filename + eval_model(global_step, writer, device, model, y, c, g, input_lengths, eval_dir, ema) + + # Update + if train: + loss.backward() + if clip_thresh > 0: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_thresh) + optimizer.step() + # update moving average + if ema is not None: + for name, param in model.named_parameters(): + if name in ema.shadow: + ema.update(name, param.data) + + # Logs + writer.add_scalar("{} loss".format(phase), float(loss.item()), step) + if train: + if clip_thresh > 0: + writer.add_scalar("gradient norm", grad_norm, step) + writer.add_scalar("learning rate", current_lr, step) + + return loss.item()
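The moving-average bookkeeping above (`ema.update(name, param.data)` together with the `ExponentialMovingAverage` class) reduces to the recurrence `shadow ← decay·shadow + (1 − decay)·x`. A dependency-free sketch on plain floats (`EmaSketch` is a hypothetical name, not part of the codebase):

```python
class EmaSketch:
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, val):
        self.shadow[name] = float(val)

    def update(self, name, x):
        # Same arithmetic as the trainer's ExponentialMovingAverage.update():
        # shadow -= (1 - decay) * (shadow - x), i.e. an exponential moving average.
        delta = self.shadow[name] - x
        self.shadow[name] -= (1.0 - self.decay) * delta

ema = EmaSketch(decay=0.9)
ema.register("w", 0.0)
ema.update("w", 1.0)  # 0.9 * 0.0 + 0.1 * 1.0 = 0.1
ema.update("w", 1.0)  # 0.9 * 0.1 + 0.1 * 1.0 = 0.19
```

Larger `decay` values make the shadow weights react more slowly, which is why the averaged checkpoint (`*_ema.pth`) is typically smoother than the raw one.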
+ + +def train_loop(device, model, data_loaders, optimizer, writer, checkpoint_dir=None): + if is_mulaw_quantize(hparams.input_type): + criterion = MaskedCrossEntropyLoss() + else: + if hparams.output_distribution == "Logistic": + criterion = DiscretizedMixturelogisticLoss() + elif hparams.output_distribution == "Normal": + criterion = MixtureGaussianLoss() + else: + raise RuntimeError( + "Not supported output distribution type: {}".format( + hparams.output_distribution)) + + if hparams.exponential_moving_average: + ema = ExponentialMovingAverage(hparams.ema_decay) + for name, param in model.named_parameters(): + if param.requires_grad: + ema.register(name, param.data) + else: + ema = None + + global global_step, global_epoch, global_test_step + while global_epoch < hparams.nepochs: + for phase, data_loader in data_loaders.items(): + train = (phase == "train_no_dev") + running_loss = 0. + test_evaluated = False + for step, (x, y, c, g, input_lengths) in tqdm(enumerate(data_loader)): + # Whether to save eval (i.e., online decoding) result + do_eval = False + eval_dir = join(checkpoint_dir, "intermediate", "{}_eval".format(phase)) + # Do eval per eval_interval for train + if train and global_step > 0 \ + and global_step % hparams.train_eval_interval == 0: + do_eval = True + # Do eval for test + # NOTE: Decoding WaveNet is quite time consuming, so + # do only once in a single epoch for testset + if not train and not test_evaluated \ + and global_epoch % hparams.test_eval_epoch_interval == 0: + do_eval = True + test_evaluated = True + if do_eval: + print("[{}] Eval at train step {}".format(phase, global_step)) + + # Do step + running_loss += __train_step(device, + phase, global_epoch, global_step, global_test_step, model, + optimizer, writer, criterion, x, y, c, g, input_lengths, + checkpoint_dir, eval_dir, do_eval, ema) + + # update global state + if train: + global_step += 1 + else: + global_test_step += 1 + + if global_step >= hparams.max_train_steps: + 
print("Training reached max train steps ({}). Will exit.".format(hparams.max_train_steps)) + return ema + + # log per epoch + averaged_loss = running_loss / len(data_loader) + writer.add_scalar("{} loss (per epoch)".format(phase), + averaged_loss, global_epoch) + print("Step {} [{}] Loss: {}".format( + global_step, phase, running_loss / len(data_loader))) + + global_epoch += 1 + return ema + + +def save_checkpoint(device, model, optimizer, step, checkpoint_dir, epoch, ema=None): + checkpoint_path = join( + checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step)) + optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None + global global_test_step + torch.save({ + "state_dict": model.state_dict(), + "optimizer": optimizer_state, + "global_step": step, + "global_epoch": epoch, + "global_test_step": global_test_step, + }, checkpoint_path) + print("Saved checkpoint:", checkpoint_path) + + import shutil + latest_pth = join(checkpoint_dir, "checkpoint_latest.pth") + shutil.copyfile(checkpoint_path, latest_pth) + + if ema is not None: + averaged_model = clone_as_averaged_model(device, model, ema) + checkpoint_path = join( + checkpoint_dir, "checkpoint_step{:09d}_ema.pth".format(global_step)) + torch.save({ + "state_dict": averaged_model.state_dict(), + "optimizer": optimizer_state, + "global_step": step, + "global_epoch": epoch, + "global_test_step": global_test_step, + }, checkpoint_path) + print("Saved averaged checkpoint:", checkpoint_path) + + latest_pth = join(checkpoint_dir, "checkpoint_latest_ema.pth") + shutil.copyfile(checkpoint_path, latest_pth) + + +def build_model(): + if is_mulaw_quantize(hparams.input_type): + if hparams.out_channels != hparams.quantize_channels: + raise RuntimeError( + "out_channels must be equal to quantize_channels if input_type is 'mulaw-quantize'") + if hparams.upsample_conditional_features and hparams.cin_channels < 0: + s = "Upsample conv layers were specified while local conditioning is disabled. 
" + s += "Notice that upsample conv layers will never be used." + warn(s) + + upsample_params = hparams.upsample_params + upsample_params["cin_channels"] = hparams.cin_channels + upsample_params["cin_pad"] = hparams.cin_pad + model = WaveNet( + out_channels=hparams.out_channels, + layers=hparams.layers, + stacks=hparams.stacks, + residual_channels=hparams.residual_channels, + gate_channels=hparams.gate_channels, + skip_out_channels=hparams.skip_out_channels, + cin_channels=hparams.cin_channels, + gin_channels=hparams.gin_channels, + n_speakers=hparams.n_speakers, + dropout=hparams.dropout, + kernel_size=hparams.kernel_size, + cin_pad=hparams.cin_pad, + upsample_conditional_features=hparams.upsample_conditional_features, + upsample_params=upsample_params, + scalar_input=is_scalar_input(hparams.input_type), + output_distribution=hparams.output_distribution, + ) + return model + + +def _load(checkpoint_path): + if use_cuda: + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, + map_location=lambda storage, loc: storage) + return checkpoint + + +def load_checkpoint(path, model, optimizer, reset_optimizer): + global global_step + global global_epoch + global global_test_step + + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + model.load_state_dict(checkpoint["state_dict"]) + if not reset_optimizer: + optimizer_state = checkpoint["optimizer"] + if optimizer_state is not None: + print("Load optimizer state from {}".format(path)) + optimizer.load_state_dict(checkpoint["optimizer"]) + global_step = checkpoint["global_step"] + global_epoch = checkpoint["global_epoch"] + global_test_step = checkpoint.get("global_test_step", 0) + + return model + + +# https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3 +def restore_parts(path, model): + print("Restore part of the model from: {}".format(path)) + state = _load(path)["state_dict"] + model_dict = model.state_dict() + valid_state_dict = {k: 
v for k, v in state.items() if k in model_dict} + + try: + model_dict.update(valid_state_dict) + model.load_state_dict(model_dict) + except RuntimeError as e: + # there should be invalid size of weight(s), so load them per parameter + print(str(e)) + model_dict = model.state_dict() + for k, v in valid_state_dict.items(): + model_dict[k] = v + try: + model.load_state_dict(model_dict) + except RuntimeError as e: + print(str(e)) + warn("{}: may contain invalid size of weight. skipping...".format(k)) + + +def get_data_loaders(dump_root, speaker_id, test_shuffle=True): + data_loaders = {} + local_conditioning = hparams.cin_channels > 0 + + if hparams.max_time_steps is not None: + max_steps = ensure_divisible(hparams.max_time_steps, audio.get_hop_size(), True) + else: + max_steps = None + + for phase in ["train_no_dev", "dev"]: + train = phase == "train_no_dev" + X = FileSourceDataset( + RawAudioDataSource(join(dump_root, phase), speaker_id=speaker_id, + max_steps=max_steps, cin_pad=hparams.cin_pad, + hop_size=audio.get_hop_size())) + if local_conditioning: + Mel = FileSourceDataset( + MelSpecDataSource(join(dump_root, phase), speaker_id=speaker_id, + max_steps=max_steps, cin_pad=hparams.cin_pad, + hop_size=audio.get_hop_size())) + assert len(X) == len(Mel) + print("Local conditioning enabled. 
Shape of a sample: {}.".format( + Mel[0].shape)) + else: + Mel = None + print("[{}]: length of the dataset is {}".format(phase, len(X))) + + if train: + lengths = np.array(X.file_data_source.lengths) + # Prepare sampler + sampler = PartialyRandomizedSimilarTimeLengthSampler( + lengths, batch_size=hparams.batch_size) + shuffle = False + # make sure that there's no sorting bugs for https://github.com/r9y9/wavenet_vocoder/issues/130 + sampler_idx = np.asarray(sorted(list(map(lambda s: int(s), sampler)))) + assert (sampler_idx == np.arange(len(sampler_idx), dtype=np.int)).all() + else: + sampler = None + shuffle = test_shuffle + + dataset = PyTorchDataset(X, Mel) + data_loader = data_utils.DataLoader( + dataset, batch_size=hparams.batch_size, drop_last=True, + num_workers=hparams.num_workers, sampler=sampler, shuffle=shuffle, + collate_fn=collate_fn, pin_memory=hparams.pin_memory) + + speaker_ids = {} + if X.file_data_source.multi_speaker: + for idx, (x, c, g) in enumerate(dataset): + if g is not None: + try: + speaker_ids[g] += 1 + except KeyError: + speaker_ids[g] = 1 + if len(speaker_ids) > 0: + print("Speaker stats:", speaker_ids) + + data_loaders[phase] = data_loader + + return data_loaders + + +if __name__ == "__main__": + args = docopt(__doc__) + print("Command line args:\n", args) + checkpoint_dir = args["--checkpoint-dir"] + checkpoint_path = args["--checkpoint"] + checkpoint_restore_parts = args["--restore-parts"] + speaker_id = args["--speaker-id"] + speaker_id = int(speaker_id) if speaker_id is not None else None + preset = args["--preset"] + + dump_root = args["--dump-root"] + if dump_root is None: + dump_root = join(dirname(__file__), "data", "ljspeech") + + log_event_path = args["--log-event-path"] + reset_optimizer = args["--reset-optimizer"] + + # Load preset if specified + if preset is not None: + with open(preset) as f: + hparams.parse_json(f.read()) + # Override hyper parameters + hparams.parse(args["--hparams"]) + assert hparams.name == 
"wavenet_vocoder" + print(hparams_debug_string()) + + fs = hparams.sample_rate + + os.makedirs(checkpoint_dir, exist_ok=True) + + output_json_path = join(checkpoint_dir, "hparams.json") + with open(output_json_path, "w") as f: + json.dump(hparams.values(), f, indent=2) + + # Dataloader setup + data_loaders = get_data_loaders(dump_root, speaker_id, test_shuffle=True) + + maybe_set_epochs_based_on_max_steps(hparams, len(data_loaders["train_no_dev"])) + + device = torch.device("cuda" if use_cuda else "cpu") + + # Model + model = build_model().to(device) + + receptive_field = model.receptive_field + print("Receptive field (samples / ms): {} / {}".format( + receptive_field, receptive_field / fs * 1000)) + + from torch import optim + Optimizer = getattr(optim, hparams.optimizer) + optimizer = Optimizer(model.parameters(), **hparams.optimizer_params) + + if checkpoint_restore_parts is not None: + restore_parts(checkpoint_restore_parts, model) + + # Load checkpoints + if checkpoint_path is not None: + load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer) + + # Setup summary writer for tensorboard + if log_event_path is None: + log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_") + print("TensorBoard event log path: {}".format(log_event_path)) + writer = SummaryWriter(log_dir=log_event_path) + + # Train! 
+ ema = None + try: + ema = train_loop(device, model, data_loaders, optimizer, writer, + checkpoint_dir=checkpoint_dir) + except KeyboardInterrupt: + print("Interrupted!") + pass + finally: + save_checkpoint( + device, model, optimizer, global_step, checkpoint_dir, global_epoch, ema) + + print("Finished") + + sys.exit(0) diff --git a/HuaWeiExperiment/wavenet/datasets/wavallin.py b/HuaWeiExperiment/wavenet/datasets/wavallin.py new file mode 100644 index 00000000..f5e25355 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/wavallin.py @@ -0,0 +1,121 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +''' +Process audio files for generating training and evaluating data. 
+''' +import os +from os.path import basename, splitext, join +import sys +from concurrent.futures import ProcessPoolExecutor +from functools import partial +from glob import glob +import audio +import librosa +from nnmnkwii import preprocessing as P +from hparams import hparams +from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw +import numpy as np +sys.path.append('.') + + +def build_from_path(in_dir, out_dir, num_workers=1, tqdm=lambda x: x): + '''processing audio files''' + executor = ProcessPoolExecutor(max_workers=num_workers) + futures = [] + index = 1 + src_files = sorted(glob(join(in_dir, "*.wav"))) + for wav_path in src_files: + futures.append(executor.submit( + partial(_process_utterance, out_dir, index, wav_path, "dummy"))) + index += 1 + + return [future.result() for future in tqdm(futures)] + + +def _process_utterance(out_dir, index, wav_path, text): + '''processing audio''' + # Load the audio to a numpy array: + wav = audio.load_wav(wav_path) + + # Trim begin/end silences + # NOTE: the threshold was chosen for clean signals + wav, _ = librosa.effects.trim(wav, top_db=60, frame_length=2048, hop_length=512) + + # Compute a mel-scale spectrogram from the trimmed wav: + # (N, D) + mel_spectrogram = audio.logmelspectrogram(wav).astype(np.float32).T + + # Mu-law quantize + if is_mulaw_quantize(hparams.input_type): + # Trim silences in mu-law quantized domain + silence_threshold = 0 + if silence_threshold > 0: + # [0, quantize_channels) + out = P.mulaw_quantize(wav, hparams.quantize_channels - 1) + start, end = audio.start_and_end_indices(out, silence_threshold) + wav = wav[start:end] + constant_values = P.mulaw_quantize(0, hparams.quantize_channels - 1) + out_dtype = np.int16 + elif is_mulaw(hparams.input_type): + # [-1, 1] + constant_values = P.mulaw(0.0, hparams.quantize_channels - 1) + out_dtype = np.float32 + else: + # [-1, 1] + constant_values = 0.0 + out_dtype = np.float32 + + if hparams.global_gain_scale > 0: + wav *=
hparams.global_gain_scale + + # Clip + if np.abs(wav).max() > 1.0: + print("""Warning: abs max value exceeds 1.0: {}""".format(np.abs(wav).max())) + # ignore this sample + return ("dummy", "dummy", -1, "dummy") + + # Set waveform target (out) + if is_mulaw_quantize(hparams.input_type): + out = P.mulaw_quantize(wav, hparams.quantize_channels - 1) + elif is_mulaw(hparams.input_type): + out = P.mulaw(wav, hparams.quantize_channels - 1) + else: + out = wav + + # zero pad + # this is needed to adjust time resolution between audio and mel-spectrogram + l, r = audio.pad_lr(out, hparams.fft_size, audio.get_hop_size()) + if l > 0 or r > 0: + out = np.pad(out, (l, r), mode="constant", constant_values=constant_values) + N = mel_spectrogram.shape[0] + assert len(out) >= N * audio.get_hop_size() + + # time resolution adjustment + # ensure length of raw audio is multiple of hop_size so that we can use + # transposed convolution to upsample + out = out[:N * audio.get_hop_size()] + assert len(out) % audio.get_hop_size() == 0 + + # Write the spectrograms to disk: + name = splitext(basename(wav_path))[0] + audio_filename = '%s-wave.npy' % name + mel_filename = '%s-feats.npy' % name + np.save(os.path.join(out_dir, audio_filename), + out.astype(out_dtype), allow_pickle=False) + np.save(os.path.join(out_dir, mel_filename), + mel_spectrogram.astype(np.float32), allow_pickle=False) + + # Return a tuple describing this training example: + return (audio_filename, mel_filename, N, text) diff --git a/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/__init__.py b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/__init__.py new file mode 100644 index 00000000..e34d0b96 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" +from __future__ import with_statement, print_function, absolute_import +from .wavenet import WaveNet diff --git a/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/conv.py b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/conv.py new file mode 100644 index 00000000..85feed02 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/conv.py @@ -0,0 +1,182 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Extended Conv1D.""" + +import math +import numpy as np +from mindspore import nn, Tensor +from mindspore.ops import operations as P +import mindspore.common.dtype as mstype +from mindspore import context + +class Conv1d(nn.Conv1d): + """ + Extended nn.Conv1d to adapt to incremental dilated convolutions. + During training, initial Conv1D is used and during evaluation, incremental_forward is called. 
+ To improve inference speed, tensors are converted to NumPy arrays and the subsequent computation runs in NumPy. + These operations will be replaced with MindSpore ops in the future; currently some operations are not supported by + MindSpore, and mixing NumPy with MindSpore ops would be slow. + + """ + + def __init__(self, *args, **kwargs): + super(Conv1d, self).__init__(*args, **kwargs) + self.clear_buffer() + self._linearized_weight = None + self.transpose_op = P.Transpose() + self.reshape_op = P.Reshape() + self.squeeze_op = P.Squeeze(-2) + self.zeros = P.Zeros() + self.concat_op = P.Concat(axis=1) + self.matmul = P.MatMul(transpose_b=True) + self.bias_add = P.BiasAdd() + self.get_weight = None + self.get_bias = None + + def incremental_forward(self, inputs, is_numpy=True): + if is_numpy: + return self.incremental_forward_numpy(inputs) + return self.incremental_forward_pynative(inputs) + + def incremental_forward_pynative(self, inputs): + """ + Incremental forward. + + Args: + inputs: B x T x C + + Returns: + Tensor + + """ + # input: (B, T, C) + if self.training: + raise RuntimeError('incremental_forward only supports eval mode') + + if self.get_weight is None: + self.get_weight = self._get_linearized_weight() + + if self.get_bias is None and self.bias is not None: + self.get_bias = self.bias + + # Note mindspore uses Conv2D to construct Conv1D + kw = self.kernel_size[1] + dilation = self.dilation[1] + + bsz = inputs.shape[0] # input: bsz x len x dim + if kw > 1: + if self.input_buffer is None: + init_buffer = self.zeros((bsz, kw + (kw - 1) * (dilation - 1), inputs.shape[2]), mstype.float32) + self.input_buffer = self.concat_op((init_buffer[:, 1:, :], inputs[:, 0:1, :])) + else: + # shift buffer + self.input_buffer = self.concat_op((self.input_buffer[:, 1:, :], inputs[:, 0:1, :])) + inputs = self.input_buffer + if dilation > 1: + if context.get_context("device_target") == "CPU": + inputs = self.transpose_op(inputs, (1, 0, 2)) + inputs = 
inputs[0::dilation, :, :] + inputs = self.transpose_op(inputs, (1, 0, 2)) + else: + inputs = inputs[:, 0::dilation, :] + + output = self.matmul(self.reshape_op(inputs, (bsz, -1)), self.get_weight) + if self.bias is not None: + output = self.bias_add(output, self.bias) + return self.reshape_op(output, (bsz, 1, -1)) + + def incremental_forward_numpy(self, inputs): + """ + Incremental forward. + + Args: + inputs: B x T x C + + Returns: + ndarray + + """ + # input: (B, T, C) + if self.training: + raise RuntimeError('incremental_forward only supports eval mode') + + if self.get_weight is None: + weight = self._get_linearized_weight() + self.get_weight = weight.asnumpy() + + if self.get_bias is None and self.bias is not None: + bias = self.bias + self.get_bias = bias.asnumpy() + + # Note mindspore uses Conv2D to construct Conv1D + kw = self.kernel_size[1] + dilation = self.dilation[1] + + bsz = inputs.shape[0] # input: bsz x len x dim + if kw > 1: + if self.input_buffer is None: + self.input_buffer = np.zeros((bsz, kw + (kw - 1) * (dilation - 1), inputs.shape[2]), dtype=np.float32) + else: + # shift buffer + self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :] + # append next + self.input_buffer[:, -1, :] = inputs[:, -1, :] + inputs = self.input_buffer + if dilation > 1: + inputs = inputs[:, 0::dilation, :] + output = inputs.reshape(bsz, -1).dot(self.get_weight.T) + if self.bias is not None: + output = output + np.expand_dims(self.get_bias, 0) + return np.reshape(output, (bsz, 1, -1)) + + def clear_buffer(self): + self.input_buffer = None + + def _get_linearized_weight(self): + """ + get linearized weight + """ + weight = self.squeeze_op(self.weight) + if self._linearized_weight is None: + # Note mindspore uses Conv2D to construct Conv1D + kw = self.kernel_size[1] + if weight.shape == (self.out_channels, self.in_channels, kw): + weight = self.transpose_op(weight, (0, 2, 1)) + else: + weight = self.transpose_op(weight, (2, 0, 1)) + self._linearized_weight = 
self.reshape_op(weight, (self.out_channels, -1)) + return self._linearized_weight + + def _clear_linearized_weight(self, *args): + self._linearized_weight = None + + def _initialize_weights(self): + """ + weight initialization + """ + self.init_parameters_data() + std_mul = 4.0 + for _, m in self.cells_and_names(): + if isinstance(m, nn.Conv1d): + std = math.sqrt((std_mul * 0.1) / (m.kernel_size[1] * self.in_channels)) + m.weight.set_data(Tensor(np.random.normal(0, std, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_data( + Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_data( + Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) diff --git a/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/mixture.py b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/mixture.py new file mode 100644 index 00000000..a594fdd4 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/mixture.py @@ -0,0 +1,386 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Loss function for training and sample function for testing. 
+""" +import numpy as np +import mindspore as ms +from mindspore import Tensor +import mindspore.nn as nn +import mindspore.ops as P +from mindspore import context + + +class log_sum_exp(nn.Cell): + """Numerically stable log_sum_exp + """ + + def __init__(self): + super(log_sum_exp, self).__init__() + self.maxi = P.ReduceMax() + self.maxi_dim = P.ReduceMax(keep_dims=True) + self.log = P.Log() + self.sums = P.ReduceSum() + self.exp = P.Exp() + + def construct(self, x): + axis = len(x.shape) - 1 + m = self.maxi(x, axis) + m2 = self.maxi_dim(x, axis) + return m + self.log(self.sums(self.exp(x - m2), axis)) + + +class log_softmax(nn.Cell): + """ + replacement of P.LogSoftmax(-1) in CPU mode + only support x.shape == 2 or 3 + """ + + def __init__(self): + super(log_softmax, self).__init__() + self.maxi = P.ReduceMax() + self.log = P.Log() + self.sums = P.ReduceSum() + self.exp = P.Exp() + self.axis = -1 + self.concat = P.Concat(-1) + self.expanddims = P.ExpandDims() + + def construct(self, x): + """ + + Args: + x (Tensor): input + + Returns: + Tensor: log_softmax of input + + """ + c = self.maxi(x, self.axis) + logs, lsm = None, None + if len(x.shape) == 2: + for j in range(x.shape[-1]): + temp = self.expanddims(self.exp(x[:, j] - c), -1) + logs = temp if j == 0 else self.concat((logs, temp)) + sums = self.sums(logs, -1) + for i in range(x.shape[-1]): + temp = self.expanddims(x[:, i] - c - self.log(sums), -1) + lsm = temp if i == 0 else self.concat((lsm, temp)) + return lsm + if len(x.shape) == 3: + for j in range(x.shape[-1]): + temp = self.expanddims(self.exp(x[:, :, j] - c), -1) + logs = temp if j == 0 else self.concat((logs, temp)) + sums = self.sums(logs, -1) + for i in range(x.shape[-1]): + temp = self.expanddims(x[:, :, i] - c - self.log(sums), -1) + lsm = temp if i == 0 else self.concat((lsm, temp)) + return lsm + return None + + +class Stable_softplus(nn.Cell): + """Numerically stable softplus + """ + + def __init__(self): + super(Stable_softplus, 
self).__init__() + self.log_op = P.Log() + self.abs_op = P.Abs() + self.relu_op = P.ReLU() + self.exp_op = P.Exp() + + def construct(self, x): + return self.log_op(1 + self.exp_op(- self.abs_op(x))) + self.relu_op(x) + + +class discretized_mix_logistic_loss(nn.Cell): + """ + Discretized_mix_logistic_loss + + Args: + num_classes (int): Num_classes + log_scale_min (float): Log scale minimum value + + """ + + def __init__(self, num_classes=256, log_scale_min=-7.0, reduce=True): + super(discretized_mix_logistic_loss, self).__init__() + self.num_classes = num_classes + self.log_scale_min = log_scale_min + self.reduce = reduce + self.transpose_op = P.Transpose() + self.exp = P.Exp() + self.sigmoid = P.Sigmoid() + self.softplus = Stable_softplus() + self.log = P.Log() + self.cast = P.Cast() + self.expand_dims = P.ExpandDims() + self.tile = P.Tile() + self.maximum = P.Maximum() + self.sums = P.ReduceSum() + self.lse = log_sum_exp() + self.reshape = P.Reshape() + self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32)) + self.tensor_one = Tensor(1., ms.float32) + + if context.get_context("device_target") == "CPU": + self.logsoftmax = log_softmax() + else: + self.logsoftmax = P.LogSoftmax(-1) + + def construct(self, y_hat, y): + """ + + Args: + y_hat (Tensor): Predicted distribution + y (Tensor): Target + + Returns: + Tensor: Discretized_mix_logistic_loss + + """ + nr_mix = y_hat.shape[1] // 3 + + # (B x T x C) + y_hat = self.transpose_op(y_hat, (0, 2, 1)) + + # (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix:2 * nr_mix] + min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix)) + log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut) + + # B x T x 1 -> B x T x num_mixtures + y = self.tile(y, (1, 1, nr_mix)) + + centered_y = y - means + inv_stdv = self.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1. 
/ (self.num_classes - 1)) + cdf_plus = self.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1. / (self.num_classes - 1)) + cdf_min = self.sigmoid(min_in) + + log_cdf_plus = plus_in - self.softplus(plus_in) + + log_one_minus_cdf_min = -self.softplus(min_in) + + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + log_pdf_mid = mid_in - log_scales - 2. * self.softplus(mid_in) + + inner_inner_cond = self.cast(cdf_delta > 1e-5, ms.float32) + min_cut2 = 1e-12 * self.tile(self.tensor_one, cdf_delta.shape) + inner_inner_out = inner_inner_cond * \ + self.log(self.maximum(cdf_delta, min_cut2)) + \ + (1. - inner_inner_cond) * (log_pdf_mid - self.factor) + inner_cond = self.cast(y > 0.999, ms.float32) + inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out + cond = self.cast(y < -0.999, ms.float32) + log_probs = cond * log_cdf_plus + (1. - cond) * inner_out + + a, b, c = logit_probs.shape[0], logit_probs.shape[1], logit_probs.shape[2] + logit_probs = self.logsoftmax(self.reshape(logit_probs, (-1, c))) + logit_probs = self.reshape(logit_probs, (a, b, c)) + + log_probs = log_probs + logit_probs + if self.reduce: + return -self.sums(self.lse(log_probs)) + return self.expand_dims(-self.lse(log_probs), -1) + + +def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0): + """ + Sample from discretized mixture of logistic distributions + + Args: + y (ndarray): B x C x T + log_scale_min (float): Log scale minimum value + + Returns: + ndarray + """ + nr_mix = y.shape[1] // 3 + + # B x T x C + y = np.transpose(y, (0, 2, 1)) + logit_probs = y[:, :, :nr_mix] + + temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape) + temp = logit_probs - np.log(- np.log(temp)) + + argmax = np.argmax(temp, axis=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = np.eye(nr_mix)[argmax] + means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1) + log_scales = np.clip(np.sum( + y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1), 
a_min=log_scale_min, a_max=None) + + u = np.random.uniform(1e-5, 1.0 - 1e-5, means.shape) + x = means + np.exp(log_scales) * (np.log(u) - np.log(1. - u)) + x = np.clip(x, -1., 1.) + return x.astype(np.float32) + + +class mix_gaussian_loss(nn.Cell): + """ + Mix gaussian loss + """ + + def __init__(self, log_scale_min=-7.0, reduce=True): + super(mix_gaussian_loss, self).__init__() + self.log_scale_min = log_scale_min + self.reduce = reduce + self.transpose_op = P.Transpose() + self.maximum = P.Maximum() + self.tile = P.Tile() + self.exp = P.Exp() + self.expand_dims = P.ExpandDims() + self.sums = P.ReduceSum() + self.lse = log_sum_exp() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.const = P.ScalarToTensor() + self.log = P.Log() + self.tensor_one = Tensor(1., ms.float32) + + if context.get_context("device_target") == "CPU": + self.logsoftmax = log_softmax() + else: + self.logsoftmax = P.LogSoftmax(-1) + + def construct(self, y_hat, y): + """ + + Args: + y_hat (Tensor): Predicted probability + y (Tensor): Target + + Returns: + Tensor: Mix_gaussian_loss + + """ + C = y_hat.shape[1] + if C == 2: + nr_mix = 1 + else: + nr_mix = y_hat.shape[1] // 3 + + # (B x T x C) + y_hat = self.transpose_op(y_hat, (0, 2, 1)) + + if C == 2: + logit_probs = None + means = y_hat[:, :, 0:1] + min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], 1)) + log_scales = self.maximum(y_hat[:, :, 1:2], min_cut) + else: + # (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix:2 * nr_mix] + min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix)) + log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut) + + # B x T x 1 -> B x T x num_mixtures + y = self.tile(y, (1, 1, nr_mix)) + centered_y = y - means + + sd = self.exp(log_scales) + unnormalized_log_prob = -1. * (self.sq(centered_y - 0.)) / (2. * self.sq(sd)) + neg_normalization = -1. * self.log(self.const(2. 
* np.pi)) / 2. - self.log(sd) + log_probs = unnormalized_log_prob + neg_normalization + + if nr_mix > 1: + log_probs = log_probs + self.logsoftmax(logit_probs) + + if self.reduce: + if nr_mix == 1: + return -self.sums(log_probs) + return -self.sums(self.lse(log_probs)) + if nr_mix == 1: + return -log_probs + return self.expand_dims(-self.lse(log_probs), -1) + + +def sample_from_mix_gaussian(y, log_scale_min=-7.0): + """ + Sample_from_mix_gaussian + + Args: + y (ndarray): B x C x T + + Returns: + ndarray + + """ + C = y.shape[1] + if C == 2: + nr_mix = 1 + else: + nr_mix = y.shape[1] // 3 + + # B x T x C + y = np.transpose(y, (0, 2, 1)) + + if C == 2: + logit_probs = None + else: + logit_probs = y[:, :, :nr_mix] + + if nr_mix > 1: + temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape) + temp = logit_probs - np.log(- np.log(temp)) + argmax = np.argmax(temp, axis=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = np.eye(nr_mix)[argmax] + + means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1) + + log_scales = np.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1) + else: + if C == 2: + means, log_scales = y[:, :, 0], y[:, :, 1] + elif C == 3: + means, log_scales = y[:, :, 1], y[:, :, 2] + else: + assert False, "shouldn't happen" + + scales = np.exp(log_scales) + x = np.random.normal(loc=means, scale=scales) + x = np.clip(x, -1., 1.) 
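The mixture-component selection used in the sampling functions above (`logit_probs - np.log(-np.log(u))` followed by `argmax`) is the Gumbel-max trick. A small sketch with illustrative weights, showing that it draws components in proportion to `softmax(logit_probs)` without computing the normalizing constant:

```python
import numpy as np

def gumbel_argmax(logit_probs, rng):
    # Adding Gumbel noise -log(-log(U)), U ~ Uniform(0, 1), to unnormalized
    # log-probabilities and taking the argmax samples a category distributed
    # as softmax(logit_probs).
    u = rng.uniform(1e-5, 1.0 - 1e-5, size=logit_probs.shape)
    return np.argmax(logit_probs - np.log(-np.log(u)), axis=-1)

rng = np.random.default_rng(0)
logits = np.log(np.array([0.1, 0.2, 0.7]))  # target mixture weights
draws = gumbel_argmax(np.tile(logits, (20000, 1)), rng)
freq = np.bincount(draws, minlength=3) / draws.size
print(freq)  # approximately [0.1, 0.2, 0.7]
```

The uniform samples are clamped away from 0 and 1, exactly as in the code above, to keep the double logarithm finite.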
+ return x.astype(np.float32) + + +# self-implemented onehotcategorical distribution +# https://zhuanlan.zhihu.com/p/59550457 +def sample_from_mix_onehotcategorical(x): + """ + Sample_from_mix_onehotcategorical + + Args: + x (ndarray): Predicted softmax probability + + Returns: + ndarray + + """ + pi = np.log(x) + u = np.random.uniform(0, 1, x.shape) + g = -np.log(-np.log(u)) + c = np.argmax(pi + g, axis=1) + return np.array(np.eye(256)[c], dtype=np.float32) diff --git a/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/modules.py b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/modules.py new file mode 100644 index 00000000..208049c7 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/modules.py @@ -0,0 +1,213 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Modules for WaveNet. 
+""" +from __future__ import with_statement, print_function, absolute_import +import math +import numpy as np +from wavenet_vocoder import conv +from mindspore import nn +from mindspore.ops import operations as P + + +def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): + m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs) + return m + + +def Conv1d1x1(in_channels, out_channels, has_bias=True): + return Conv1d(in_channels, out_channels, kernel_size=1, pad_mode='pad', padding=0, dilation=1, has_bias=has_bias) + + +def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + return m + + +def _conv1x1_forward(conv_, x, is_incremental, is_numpy=True): + """ + Conv1x1 forward + """ + if is_incremental: + x = conv_.incremental_forward(x, is_numpy=is_numpy) + else: + x = conv_(x) + return x + + +class ResidualConv1dGLU(nn.Cell): + """Residual dilated conv1d with gated activation units + + Args: + residual_channels (int): Residual input / output channels + gate_channels (int): Gated activation channels. + kernel_size (int): Kernel size + skip_out_channels (int): Skip connection channels. If None, it will set to the same as residual_channels. + cin_channels (int): Local conditioning channels. If given negative value, local conditioning is disabled. + gin_channels (int): Global conditioning channels. If given negative value, global conditioning is disabled. + dropout (float): Dropout rate. + padding (int): Padding for convolution layers. If None, padding value will be computed according to dilation + and kernel_size. + dilation (int): Dilation factor. 
+ + """ + + def __init__(self, residual_channels=None, gate_channels=None, kernel_size=None, skip_out_channels=None, bias=True, + dropout=1 - 0.95, dilation=1, cin_channels=-1, gin_channels=-1, padding=None, causal=True): + super(ResidualConv1dGLU, self).__init__() + self.dropout = dropout + self.dropout_op = nn.Dropout(p=self.dropout) + self.eval_split_op = P.Split(axis=-1, output_num=2) + self.train_split_op = P.Split(axis=1, output_num=2) + self.tanh = P.Tanh() + self.sigmoid = P.Sigmoid() + self.mul = P.Mul() + self.add = P.Add() + + if skip_out_channels is None: + skip_out_channels = residual_channels + if padding is None: + if causal: + padding = (kernel_size - 1) * dilation + else: + padding = (kernel_size - 1) // 2 * dilation + self.causal = causal + + self.conv = Conv1d(residual_channels, gate_channels, kernel_size, pad_mode='pad', + padding=padding, dilation=dilation, has_bias=bias) + + # local conditioning + if cin_channels > 0: + self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, has_bias=False) + else: + self.conv1x1c = None + + # global conditioning + if gin_channels > 0: + self.conv1x1g = Conv1d(gin_channels, gate_channels, has_bias=False, kernel_size=1, dilation=1) + else: + self.conv1x1g = None + + gate_out_channels = gate_channels // 2 + self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, has_bias=bias) + self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, has_bias=bias) + self.factor = math.sqrt(0.5) + + def construct(self, x, c=None, g=None): + """ + + Args: + x(Tensor): One-hot audio signal, the shape is B x C x T + c(Tensor): local conditional feature, the shape is B x cin_channels x T + g(Tensor): global conditional feature, not used currently + + Returns: + Tensor: Output tensor + + """ + + residual = x + x = self.dropout_op(x) + x = self.conv(x) + # remove future time steps + x = x[:, :, :residual.shape[-1]] if self.causal else x + split_op = self.train_split_op + + a, b = split_op(x) + + # local 
conditioning + if c is not None: + c = _conv1x1_forward(self.conv1x1c, c, is_incremental=False) + ca, cb = split_op(c) + a, b = a + ca, b + cb + + # global conditioning + if g is not None: + g = _conv1x1_forward(self.conv1x1g, g, is_incremental=False) + ga, gb = split_op(g) + a, b = a + ga, b + gb + + x = self.mul(self.tanh(a), self.sigmoid(b)) + + # For skip connection + s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental=False) + + # For residual connection + x = _conv1x1_forward(self.conv1x1_out, x, is_incremental=False) + + x = self.add(x, residual) * self.factor + return x, s + + def sigmoid_numpy(self, x): + return 1. / (1 + np.exp(-x)) + + def incremental_forward(self, x, c=None, g=None, is_numpy=True): + """ + Incremental forward. Used for the inference stage. + + Args: + x (Tensor): One-hot audio signal, the shape is B x C x T + c (Tensor): local conditional feature, the shape is B x cin_channels x T + g (Tensor): global conditional feature, not used currently + + Returns: + ndarray (or Tensor when is_numpy is False) + """ + residual = x + x = self.conv.incremental_forward(x, is_numpy=is_numpy) + if is_numpy: + a, b = np.split(x, indices_or_sections=2, axis=-1) + else: + a, b = self.eval_split_op(x) + + # local conditioning + if c is not None: + c = _conv1x1_forward(self.conv1x1c, c, is_incremental=True, is_numpy=is_numpy) + if is_numpy: + ca, cb = np.split(c, indices_or_sections=2, axis=-1) + else: + ca, cb = self.eval_split_op(c) + a, b = a + ca, b + cb + + # global conditioning + if g is not None: + g = _conv1x1_forward(self.conv1x1g, g, is_incremental=True, is_numpy=is_numpy) + if is_numpy: + ga, gb = np.split(g, indices_or_sections=2, axis=-1) + else: + ga, gb = self.eval_split_op(g) + a, b = a + ga, b + gb + + if is_numpy: + x = np.tanh(a) * self.sigmoid_numpy(b) + else: + x = self.mul(self.tanh(a), self.sigmoid(b)) + + # For skip connection + s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental=True, is_numpy=is_numpy) + + # For residual connection + x = 
_conv1x1_forward(self.conv1x1_out, x, is_incremental=True, is_numpy=is_numpy) + + x = (x + residual) * self.factor + return x, s + + def clear_buffer(self): + """clear buffer""" + for c in [self.conv, self.conv1x1_out, self.conv1x1_skip, + self.conv1x1c, self.conv1x1g]: + if c is not None: + c.clear_buffer() diff --git a/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/tfcompat/__init__.py b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/tfcompat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/tfcompat/hparam.py b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/tfcompat/hparam.py new file mode 100644 index 00000000..c428176b --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/tfcompat/hparam.py @@ -0,0 +1,726 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Hyperparameter values.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import numbers +import re + +import six + +## from tensorflow.contrib.training.python.training import hparam_pb2 +## from tensorflow.python.framework import ops +## from tensorflow.python.util import compat +## from tensorflow.python.util import deprecation + +# Define the regular expression for parsing a single clause of the input +# (delimited by commas). A legal clause looks like: +# []? = +# where is either a single token or [] enclosed list of tokens. +# For example: "var[1] = a" or "x = [1,2,3]" +PARAM_RE = re.compile(r""" + (?P[a-zA-Z][\w\.]*) # variable name: "var" or "x" + (\[\s*(?P\d+)\s*\])? # (optional) index: "1" or None + \s*=\s* + ((?P[^,\[]*) # single value: "a" or None + | + \[(?P[^\]]*)\]) # list of values: None or "1,2,3" + ($|,\s*)""", re.VERBOSE) + + +def _parse_fail(name, var_type, value, values): + """Helper function for raising a value error for bad assignment.""" + raise ValueError( + 'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' % + (name, var_type.__name__, value, values)) + + +def _reuse_fail(name, values): + """Helper function for raising a value error for reuse of name.""" + raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name, + values)) + + +def _process_scalar_value(name, parse_fn, var_type, m_dict, values, + results_dictionary): + """Update results_dictionary with a scalar value. + + Used to update the results_dictionary to be returned by parse_values when + encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) + + Mutates results_dictionary. + + Args: + name: Name of variable in assignment ("s" or "arr"). + parse_fn: Function for parsing the actual value. + var_type: Type of named variable. 
+ m_dict: Dictionary constructed from regex parsing. + m_dict['val']: RHS value (scalar) + m_dict['index']: List index value (or None) + values: Full expression being parsed + results_dictionary: The dictionary being updated for return by the parsing + function. + + Raises: + ValueError: If the name has already been used. + """ + try: + parsed_value = parse_fn(m_dict['val']) + except ValueError: + _parse_fail(name, var_type, m_dict['val'], values) + + # If no index is provided + if not m_dict['index']: + if name in results_dictionary: + _reuse_fail(name, values) + results_dictionary[name] = parsed_value + else: + if name in results_dictionary: + # The name has already been used as a scalar, then it + # will be in this dictionary and map to a non-dictionary. + if not isinstance(results_dictionary.get(name), dict): + _reuse_fail(name, values) + else: + results_dictionary[name] = {} + + index = int(m_dict['index']) + # Make sure the index position hasn't already been assigned a value. + if index in results_dictionary[name]: + _reuse_fail('{}[{}]'.format(name, index), values) + results_dictionary[name][index] = parsed_value + + +def _process_list_value(name, parse_fn, var_type, m_dict, values, + results_dictionary): + """Update results_dictionary from a list of values. + + Used to update results_dictionary to be returned by parse_values when + encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) + + Mutates results_dictionary. + + Args: + name: Name of variable in assignment ("arr"). + parse_fn: Function for parsing individual values. + var_type: Type of named variable. + m_dict: Dictionary constructed from regex parsing. + m_dict['val']: RHS value (scalar) + values: Full expression being parsed + results_dictionary: The dictionary being updated for return by the parsing + function. + + Raises: + ValueError: If the name has an index or the values cannot be parsed. 
+ """ + if m_dict['index'] is not None: + raise ValueError('Assignment of a list to a list index.') + elements = filter(None, re.split('[ ,]', m_dict['vals'])) + # Make sure the name hasn't already been assigned a value + if name in results_dictionary: + raise _reuse_fail(name, values) + try: + results_dictionary[name] = [parse_fn(e) for e in elements] + except ValueError: + _parse_fail(name, var_type, m_dict['vals'], values) + + +def _cast_to_type_if_compatible(name, param_type, value): + """Cast hparam to the provided type, if compatible. + + Args: + name: Name of the hparam to be cast. + param_type: The type of the hparam. + value: The value to be cast, if compatible. + + Returns: + The result of casting `value` to `param_type`. + + Raises: + ValueError: If the type of `value` is not compatible with param_type. + * If `param_type` is a string type, but `value` is not. + * If `param_type` is a boolean, but `value` is not, or vice versa. + * If `param_type` is an integer type, but `value` is not. + * If `param_type` is a float type, but `value` is not a numeric type. + """ + fail_msg = ( + "Could not cast hparam '%s' of type '%s' from value %r" % + (name, param_type, value)) + + # Some callers use None, for which we can't do any casting/checking. :( + if issubclass(param_type, type(None)): + return value + + # Avoid converting a non-string type to a string. + if (issubclass(param_type, (six.string_types, six.binary_type)) and + not isinstance(value, (six.string_types, six.binary_type))): + raise ValueError(fail_msg) + + # Avoid converting a number or string type to a boolean or vice versa. + if issubclass(param_type, bool) != isinstance(value, bool): + raise ValueError(fail_msg) + + # Avoid converting float to an integer (the reverse is fine). + if (issubclass(param_type, numbers.Integral) and + not isinstance(value, numbers.Integral)): + raise ValueError(fail_msg) + + # Avoid converting a non-numeric type to a numeric type. 
+ if (issubclass(param_type, numbers.Number) and + not isinstance(value, numbers.Number)): + raise ValueError(fail_msg) + + return param_type(value) + + +def parse_values(values, type_map): + """Parses hyperparameter values from a string into a python map. + + `values` is a string containing comma-separated `name=value` pairs. + For each pair, the value of the hyperparameter named `name` is set to + `value`. + + If a hyperparameter name appears multiple times in `values`, a ValueError + is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2'). + + If a hyperparameter name appears in both an index assignment and a scalar + assignment, a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1'). + + The hyperparameter name may contain '.' symbols, which will result in an + attribute name that is only accessible through the getattr and setattr + functions. (And must first be explicitly added through add_hparam.) + + WARNING: Use of '.' in your variable names is allowed, but is not well + supported and not recommended. + + The `value` in `name=value` must follow the syntax according to the + type of the parameter: + + * Scalar integer: A Python-parsable integer value. E.g.: 1, + 100, -12. + * Scalar float: A Python-parsable floating point value. E.g.: 1.0, + -.54e89. + * Boolean: Either true or false. + * Scalar string: A non-empty sequence of characters, excluding comma, + spaces, and square brackets. E.g.: foo, bar_1. + * List: A comma separated list of scalar values of the parameter type + enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low]. + + When index assignment is used, the corresponding type_map key should be the + list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not + "arr[1]"). + + Args: + values: String. Comma separated list of `name=value` pairs where + 'value' must follow the syntax described above. + type_map: A dictionary mapping hyperparameter names to types. Note every + parameter name in values must be a key in type_map.
The values must + conform to the types indicated, where a value V is said to conform to a + type T if either V has type T, or V is a list of elements of type T. + Hence, for a multidimensional parameter 'x' taking float values, + 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. + + Returns: + A python map mapping each name to either: + * A scalar value. + * A list of scalar values. + * A dictionary mapping index numbers to scalar values. + (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}") + + Raises: + ValueError: If there is a problem with input. + * If `values` cannot be parsed. + * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]'). + * If the same rvalue is assigned two different values (e.g. 'a=1,a=2', + 'a[1]=1,a[1]=2', or 'a=1,a=[1]') + """ + results_dictionary = {} + pos = 0 + while pos < len(values): + m = PARAM_RE.match(values, pos) + if not m: + raise ValueError('Malformed hyperparameter value: %s' % values[pos:]) + # Check that there is a comma between parameters and move past it. + pos = m.end() + # Parse the values. 
+ m_dict = m.groupdict() + name = m_dict['name'] + if name not in type_map: + raise ValueError('Unknown hyperparameter type for %s' % name) + type_ = type_map[name] + + # Set up correct parsing function (depending on whether type_ is a bool) + if type_ == bool: + + def parse_bool(value): + if value in ['true', 'True']: + return True + elif value in ['false', 'False']: + return False + else: + try: + return bool(int(value)) + except ValueError: + _parse_fail(name, type_, value, values) + + parse = parse_bool + else: + parse = type_ + + # If a single value is provided + if m_dict['val'] is not None: + _process_scalar_value(name, parse, type_, m_dict, values, + results_dictionary) + + # If the assigned value is a list: + elif m_dict['vals'] is not None: + _process_list_value(name, parse, type_, m_dict, values, + results_dictionary) + + else: # Not assigned a list or value + _parse_fail(name, type_, '', values) + + return results_dictionary + + +class HParams(object): + """Class to hold a set of hyperparameters as name-value pairs. + + A `HParams` object holds hyperparameters used to build and train a model, + such as the number of hidden units in a neural net layer or the learning rate + to use when training. + + You first create a `HParams` object by specifying the names and values of the + hyperparameters. + + To make them easily accessible the parameter names are added as direct + attributes of the class. A typical usage is as follows: + + ```python + # Create a HParams object specifying names and values of the model + # hyperparameters: + hparams = HParams(learning_rate=0.1, num_hidden_units=100) + + # The hyperparameters are available as attributes of the HParams object: + hparams.learning_rate ==> 0.1 + hparams.num_hidden_units ==> 100 + ``` + + Hyperparameters have a type, which is inferred from the type of their value + passed at construction time. The currently supported types are: integer, + float, boolean, string, and list of integer, float, boolean, or string.
+ + You can override hyperparameter values by calling the + [`parse()`](#HParams.parse) method, passing a string of comma separated + `name=value` pairs. This is intended to make it possible to override + any hyperparameter values from a single command-line flag to which + the user passes 'hyper-param=value' pairs. It avoids having to define + one flag for each hyperparameter. + + The syntax expected for each value depends on the type of the parameter. + See `parse()` for a description of the syntax. + + Example: + + ```python + # Define a command line flag to pass name=value pairs. + # For example using argparse: + import argparse + parser = argparse.ArgumentParser(description='Train my model.') + parser.add_argument('--hparams', type=str, + help='Comma separated list of "name=value" pairs.') + args = parser.parse_args() + ... + def my_program(): + # Create a HParams object specifying the names and values of the + # model hyperparameters: + hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100, + activations=['relu', 'tanh']) + + # Override hyperparameters values by parsing the command line + hparams.parse(args.hparams) + + # If the user passed `--hparams=learning_rate=0.3` on the command line + # then 'hparams' has the following attributes: + hparams.learning_rate ==> 0.3 + hparams.num_hidden_units ==> 100 + hparams.activations ==> ['relu', 'tanh'] + + # If the hyperparameters are in json format use parse_json: + hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}') + ``` + """ + + _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks. + + def __init__(self, hparam_def=None, model_structure=None, **kwargs): + """Create an instance of `HParams` from keyword arguments. + + The keyword arguments specify name-values pairs for the hyperparameters. + The parameter types are inferred from the type of the values passed. 
+ + The parameter names are added as attributes of `HParams` object, so they + can be accessed directly with the dot notation `hparams._name_`. + + Example: + + ```python + # Define 3 hyperparameters: 'learning_rate' is a float parameter, + # 'num_hidden_units' an integer parameter, and 'activation' a string + # parameter. + hparams = tf.HParams( + learning_rate=0.1, num_hidden_units=100, activation='relu') + + hparams.activation ==> 'relu' + ``` + + Note that a few names are reserved and cannot be used as hyperparameter + names. If you use one of the reserved name the constructor raises a + `ValueError`. + + Args: + hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef + protocol buffer. If provided, this object is initialized by + deserializing hparam_def. Otherwise **kwargs is used. + model_structure: An instance of ModelStructure, defining the feature + crosses to be used in the Trial. + **kwargs: Key-value pairs where the key is the hyperparameter name and + the value is the value for the parameter. + + Raises: + ValueError: If both `hparam_def` and initialization values are provided, + or if one of the arguments is invalid. + + """ + # Register the hyperparameters and their type in _hparam_types. + # This simplifies the implementation of parse(). + # _hparam_types maps the parameter name to a tuple (type, bool). + # The type value is the type of the parameter for scalar hyperparameters, + # or the type of the list elements for multidimensional hyperparameters. + # The bool value is True if the value is a list, False otherwise. 
+ self._hparam_types = {} + self._model_structure = model_structure + if hparam_def: +## self._init_from_proto(hparam_def) +## if kwargs: +## raise ValueError('hparam_def and initialization values are ' +## 'mutually exclusive') + raise ValueError('hparam_def has been disabled in this version') + else: + for name, value in six.iteritems(kwargs): + self.add_hparam(name, value) + +## def _init_from_proto(self, hparam_def): +## """Creates a new HParams from `HParamDef` protocol buffer. +## +## Args: +## hparam_def: `HParamDef` protocol buffer. +## """ +## assert isinstance(hparam_def, hparam_pb2.HParamDef) +## for name, value in hparam_def.hparam.items(): +## kind = value.WhichOneof('kind') +## if kind.endswith('_value'): +## # Single value. +## if kind.startswith('int64'): +## # Setting attribute value to be 'int' to ensure the type is compatible +## # with both Python2 and Python3. +## self.add_hparam(name, int(getattr(value, kind))) +## elif kind.startswith('bytes'): +## # Setting attribute value to be 'str' to ensure the type is compatible +## # with both Python2 and Python3. UTF-8 encoding is assumed. +## self.add_hparam(name, compat.as_str(getattr(value, kind))) +## else: +## self.add_hparam(name, getattr(value, kind)) +## else: +## # List of values. +## if kind.startswith('int64'): +## # Setting attribute value to be 'int' to ensure the type is compatible +## # with both Python2 and Python3. +## self.add_hparam(name, [int(v) for v in getattr(value, kind).value]) +## elif kind.startswith('bytes'): +## # Setting attribute value to be 'str' to ensure the type is compatible +## # with both Python2 and Python3. UTF-8 encoding is assumed. +## self.add_hparam( +## name, [compat.as_str(v) for v in getattr(value, kind).value]) +## else: +## self.add_hparam(name, [v for v in getattr(value, kind).value]) + + def add_hparam(self, name, value): + """Adds {name, value} pair to hyperparameters. + + Args: + name: Name of the hyperparameter. 
+ value: Value of the hyperparameter. Can be one of the following types: + int, float, string, int list, float list, or string list. + + Raises: + ValueError: if one of the arguments is invalid. + """ + # Keys in kwargs are unique, but 'name' could the name of a pre-existing + # attribute of this object. In that case we refuse to use it as a + # hyperparameter name. + if getattr(self, name, None) is not None: + raise ValueError('Hyperparameter name is reserved: %s' % name) + if isinstance(value, (list, tuple)): + if not value: + raise ValueError( + 'Multi-valued hyperparameters cannot be empty: %s' % name) + self._hparam_types[name] = (type(value[0]), True) + else: + self._hparam_types[name] = (type(value), False) + setattr(self, name, value) + + def set_hparam(self, name, value): + """Set the value of an existing hyperparameter. + + This function verifies that the type of the value matches the type of the + existing hyperparameter. + + Args: + name: Name of the hyperparameter. + value: New value of the hyperparameter. + + Raises: + ValueError: If there is a type mismatch. + """ + param_type, is_list = self._hparam_types[name] + if isinstance(value, list): + if not is_list: + raise ValueError( + 'Must not pass a list for single-valued parameter: %s' % name) + setattr(self, name, [ + _cast_to_type_if_compatible(name, param_type, v) for v in value]) + else: + if is_list: + raise ValueError( + 'Must pass a list for multi-valued parameter: %s.' % name) + setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) + + def del_hparam(self, name): + """Removes the hyperparameter with key 'name'. + + Args: + name: Name of the hyperparameter. + """ + if hasattr(self, name): + delattr(self, name) + del self._hparam_types[name] + + def parse(self, values): + """Override hyperparameter values, parsing new values from a string. + + See parse_values for more detail on the allowed format for values. + + Args: + values: String. 
Comma separated list of `name=value` pairs where + 'value' must follow the syntax described above. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values` cannot be parsed. + """ + type_map = dict() + for name, t in self._hparam_types.items(): + param_type, _ = t + type_map[name] = param_type + + values_map = parse_values(values, type_map) + return self.override_from_dict(values_map) + + def override_from_dict(self, values_dict): + """Override hyperparameter values, parsing new values from a dictionary. + + Args: + values_dict: Dictionary of name:value pairs. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values_dict` cannot be parsed. + """ + for name, value in values_dict.items(): + self.set_hparam(name, value) + return self + +## @deprecation.deprecated(None, 'Use `override_from_dict`.') + def set_from_map(self, values_map): + """DEPRECATED. Use override_from_dict.""" + return self.override_from_dict(values_dict=values_map) + + def set_model_structure(self, model_structure): + self._model_structure = model_structure + + def get_model_structure(self): + return self._model_structure + + def to_json(self, indent=None, separators=None, sort_keys=False): + """Serializes the hyperparameters into JSON. + + Args: + indent: If a non-negative integer, JSON array elements and object members + will be pretty-printed with that indent level. An indent level of 0, or + negative, will only insert newlines. `None` (the default) selects the + most compact representation. + separators: Optional `(item_separator, key_separator)` tuple. Default is + `(', ', ': ')`. + sort_keys: If `True`, the output dictionaries will be sorted by key. + + Returns: + A JSON string. + """ + return json.dumps( + self.values(), + indent=indent, + separators=separators, + sort_keys=sort_keys) + + def parse_json(self, values_json): + """Override hyperparameter values, parsing new values from a json object. 
+ + Args: + values_json: String containing a json object of name:value pairs. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values_json` cannot be parsed. + """ + values_map = json.loads(values_json) + return self.override_from_dict(values_map) + + def values(self): + """Return the hyperparameter values as a Python dictionary. + + Returns: + A dictionary with hyperparameter names as keys. The values are the + hyperparameter values. + """ + return {n: getattr(self, n) for n in self._hparam_types.keys()} + + def get(self, key, default=None): + """Returns the value of `key` if it exists, else `default`.""" + if key in self._hparam_types: + # Ensure that default is compatible with the parameter type. + if default is not None: + param_type, is_param_list = self._hparam_types[key] + type_str = 'list<%s>' % param_type if is_param_list else str(param_type) + fail_msg = ("Hparam '%s' of type '%s' is incompatible with " + 'default=%s' % (key, type_str, default)) + + is_default_list = isinstance(default, list) + if is_param_list != is_default_list: + raise ValueError(fail_msg) + + try: + if is_default_list: + for value in default: + _cast_to_type_if_compatible(key, param_type, value) + else: + _cast_to_type_if_compatible(key, param_type, default) + except ValueError as e: + raise ValueError('%s. %s' % (fail_msg, e)) + + return getattr(self, key) + + return default + + def __contains__(self, key): + return key in self._hparam_types + + def __str__(self): + return str(sorted(self.values().items())) + + def __repr__(self): + return '%s(%s)' % (type(self).__name__, self.__str__()) + + @staticmethod + def _get_kind_name(param_type, is_list): + """Returns the field name given parameter type and is_list. + + Args: + param_type: Data type of the hparam. + is_list: Whether this is a list. + + Returns: + A string representation of the field name. + + Raises: + ValueError: If parameter type is not recognized. 
+ """ + if issubclass(param_type, bool): + # This check must happen before issubclass(param_type, six.integer_types), + # since Python considers bool to be a subclass of int. + typename = 'bool' + elif issubclass(param_type, six.integer_types): + # Setting 'int' and 'long' types to be 'int64' to ensure the type is + # compatible with both Python2 and Python3. + typename = 'int64' + elif issubclass(param_type, (six.string_types, six.binary_type)): + # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is + # compatible with both Python2 and Python3. + typename = 'bytes' + elif issubclass(param_type, float): + typename = 'float' + else: + raise ValueError('Unsupported parameter type: %s' % str(param_type)) + + suffix = 'list' if is_list else 'value' + return '_'.join([typename, suffix]) + +## def to_proto(self, export_scope=None): # pylint: disable=unused-argument +## """Converts a `HParams` object to a `HParamDef` protocol buffer. +## +## Args: +## export_scope: Optional `string`. Name scope to remove. +## +## Returns: +## A `HParamDef` protocol buffer. +## """ +## hparam_proto = hparam_pb2.HParamDef() +## for name in self._hparam_types: +## # Parse the values. 
+## param_type, is_list = self._hparam_types.get(name, (None, None)) +## kind = HParams._get_kind_name(param_type, is_list) +## +## if is_list: +## if kind.startswith('bytes'): +## v_list = [compat.as_bytes(v) for v in getattr(self, name)] +## else: +## v_list = [v for v in getattr(self, name)] +## getattr(hparam_proto.hparam[name], kind).value.extend(v_list) +## else: +## v = getattr(self, name) +## if kind.startswith('bytes'): +## v = compat.as_bytes(getattr(self, name)) +## setattr(hparam_proto.hparam[name], kind, v) +## +## return hparam_proto + +## @staticmethod +## def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument +## return HParams(hparam_def=hparam_def) + + +## ops.register_proto_function( +## 'hparams', +## proto_type=hparam_pb2.HParamDef, +## to_proto=HParams.to_proto, +## from_proto=HParams.from_proto) diff --git a/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/tfcompat/readme.md b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/tfcompat/readme.md new file mode 100644 index 00000000..3d94e4c4 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/tfcompat/readme.md @@ -0,0 +1,8 @@ +Source: hparam.py copied from tensorflow v1.12.0. + +https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/training/python/training/hparam.py + +retrieved with the following command: +wget https://github.com/tensorflow/tensorflow/raw/v1.12.0/tensorflow/contrib/training/python/training/hparam.py + +Once all other tensorflow dependencies of this file are removed, the class still serves its purpose. The functions removed in this process are not used in this project.
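For reference, the clause grammar accepted by the copied hparam.py can be exercised in isolation. The sketch below is an illustrative, self-contained reimplementation of the matching loop in `parse_values` (the helper name `split_clauses` is an assumption, not part of the copied file); it reuses the same `PARAM_RE` regular expression to split a comma-separated hyperparameter string into clauses:

```python
import re

# Same clause grammar as PARAM_RE in hparam.py: "name=value" or "name[index]=value",
# where value is a single token or a []-enclosed list, clauses separated by commas.
PARAM_RE = re.compile(r"""
  (?P<name>[a-zA-Z][\w\.]*)      # variable name: "var" or "x"
  (\[\s*(?P<index>\d+)\s*\])?    # (optional) index: "1" or None
  \s*=\s*
  ((?P<val>[^,\[]*)              # single value: "a" or None
   |
   \[(?P<vals>[^\]]*)\])         # list of values: None or "1,2,3"
  ($|,\s*)""", re.VERBOSE)


def split_clauses(values):
    """Walk a comma-separated hparam string clause by clause, as parse_values does."""
    pos, clauses = 0, []
    while pos < len(values):
        m = PARAM_RE.match(values, pos)
        if not m:
            raise ValueError('Malformed hyperparameter value: %s' % values[pos:])
        pos = m.end()
        clauses.append((m.group('name'), m.group('index'),
                        m.group('val'), m.group('vals')))
    return clauses


if __name__ == '__main__':
    print(split_clauses('learning_rate=0.3,layers=[24,32],arr[1]=7'))
```

Scalar assignments populate `val`, list assignments populate `vals`, and indexed assignments carry the index separately — exactly the split that `_process_scalar_value` and `_process_list_value` consume.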
diff --git a/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/upsample.py b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/upsample.py new file mode 100644 index 00000000..32b4ba15 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/upsample.py @@ -0,0 +1,111 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Upsampling. +""" +from __future__ import with_statement, print_function, absolute_import +import numpy as np +from mindspore import nn +from mindspore.ops import operations as P + + +class Resize(nn.Cell): + """ + Resize input Tensor + """ + + def __init__(self, x_scale, y_scale, mode="nearest"): + super(Resize, self).__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def construct(self, x): + _, _, h, w = x.shape + interpolate_op = P.ResizeNearestNeighbor((self.y_scale * h, self.x_scale * w)) + return interpolate_op(x) + + +def _get_activation(upsample_activation): + """get activation""" + nonlinear = getattr(nn, upsample_activation) + return nonlinear + + +class UpsampleNetwork(nn.Cell): + """UpsampleNetwork""" + def __init__(self, upsample_scales, mode="nearest", + freq_axis_kernel_size=1, cin_pad=0, cin_channels=80): + super(UpsampleNetwork, self).__init__() + self.expand_op = P.ExpandDims() + self.squeeze_op = P.Squeeze(1) + up_layers = [] + total_scale = 
np.prod(upsample_scales) + self.indent = cin_pad * total_scale + for scale in upsample_scales: + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + k_size = (freq_axis_kernel_size, scale * 2 + 1) + padding = (freq_axis_padding, freq_axis_padding, scale, scale) + stretch = Resize(scale, 1, mode) + conv = nn.Conv2d(1, 1, kernel_size=k_size, has_bias=False, pad_mode='pad', padding=padding) + up_layers.append(stretch) + up_layers.append(conv) + self.up_layers = nn.CellList(up_layers) + + def construct(self, c): + """ + + Args: + c (Tensor): Local conditioning feature + + Returns: + Tensor: Upsampling feature + + """ + # B x 1 x C x T + c = self.expand_op(c, 1) + for f in self.up_layers: + c = f(c) + # B x C x T + c = self.squeeze_op(c) + + return c + + +class ConvInUpsampleNetwork(nn.Cell): + """Upsample Network + + Args: + upsample_scales (list): Upsample_scales list. + upsample_activation (str): Upsample_activation. + mode (str): Resize mode, default is NearestNeighbor. + cin_channels (int): Local conditioning channels. + freq_axis_kernel_size (int): Freq-axis kernel_size for the convolution layers after resize. 
+ + """ + + def __init__(self, upsample_scales, mode="nearest", + freq_axis_kernel_size=1, cin_pad=0, + cin_channels=80): + super(ConvInUpsampleNetwork, self).__init__() + ks = 2 * cin_pad + 1 + self.conv_in = nn.Conv1d(cin_channels, cin_channels, kernel_size=ks, has_bias=False, pad_mode='pad', padding=0) + self.upsample = UpsampleNetwork(upsample_scales, mode, freq_axis_kernel_size, cin_pad=0, + cin_channels=cin_channels) + + def construct(self, c): + c = self.conv_in(c) + c_up = self.upsample(c) + return c_up diff --git a/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/util.py b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/util.py new file mode 100644 index 00000000..4fea1d98 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/util.py @@ -0,0 +1,25 @@ +# coding: utf-8 +from __future__ import with_statement, print_function, absolute_import + + +def _assert_valid_input_type(s): + assert s == "mulaw-quantize" or s == "mulaw" or s == "raw" + + +def is_mulaw_quantize(s): + _assert_valid_input_type(s) + return s == "mulaw-quantize" + + +def is_mulaw(s): + _assert_valid_input_type(s) + return s == "mulaw" + + +def is_raw(s): + _assert_valid_input_type(s) + return s == "raw" + + +def is_scalar_input(s): + return is_raw(s) or is_mulaw(s) diff --git a/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/wavenet.py b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/wavenet.py new file mode 100644 index 00000000..50ca5b03 --- /dev/null +++ b/HuaWeiExperiment/wavenet/datasets/wavenet_vocoder/wavenet.py @@ -0,0 +1,335 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +WaveNet construction. +""" +from __future__ import with_statement, print_function, absolute_import + +import math +import numpy as np + +from mindspore import nn, Tensor +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from wavenet_vocoder import upsample +from .modules import Embedding +from .modules import Conv1d1x1 +from .modules import ResidualConv1dGLU +from .mixture import sample_from_discretized_mix_logistic +from .mixture import sample_from_mix_gaussian +from .mixture import sample_from_mix_onehotcategorical + + +class WaveNet(nn.Cell): + """ + WaveNet model definition. Only local conditioning is supported. + + Args: + out_channels (int): Output channels. If input_type is a mu-law quantized one-hot vector, it should equal the + quantize channels. Otherwise, it equals num_mixtures x 3. Default: 256. + layers (int): Number of ResidualConv1dGLU layers. + stacks (int): Number of dilation cycles. + residual_channels (int): Residual input / output channels. + gate_channels (int): Gated activation channels. + skip_out_channels (int): Skip connection channels. + kernel_size (int): Kernel size. + dropout (float): Dropout rate. + cin_channels (int): Local conditioning channels. If given a negative value, local conditioning is disabled. + gin_channels (int): Global conditioning channels. If given a negative value, global conditioning is disabled. + n_speakers (int): Number of speakers. This is used when global conditioning is enabled.
+ upsample_conditional_features (bool): Whether upsampling local conditioning features by resize_nearestneighbor + and conv or not. + scalar_input (Bool): If True, scalar input ([-1, 1]) is expected, otherwise, quantized one-hot vector + is expected. + use_speaker_embedding (Bool): Use speaker embedding or Not. + + """ + + def __init__(self, out_channels=256, layers=20, stacks=2, + residual_channels=512, + gate_channels=512, + skip_out_channels=512, + kernel_size=3, dropout=1 - 0.95, + cin_channels=-1, gin_channels=-1, n_speakers=None, + upsample_conditional_features=False, + upsample_net="ConvInUpsampleNetwork", + upsample_params=None, + scalar_input=False, + use_speaker_embedding=False, + output_distribution="Logistic", + cin_pad=0, + ): + super(WaveNet, self).__init__() + self.transpose_op = P.Transpose() + self.softmax = P.Softmax(axis=1) + self.reshape_op = P.Reshape() + self.zeros_op = P.Zeros() + self.ones_op = P.Ones() + self.squeeze_op = P.Squeeze() + self.expandim_op = P.ExpandDims() + self.transpose_op = P.Transpose() + self.tile_op = P.Tile() + self.scalar_input = scalar_input + self.out_channels = out_channels + self.cin_channels = cin_channels + self.output_distribution = output_distribution + self.fack_data = P.Zeros() + assert layers % stacks == 0 + layers_per_stack = layers // stacks # 24 / 4 = 6 + if scalar_input: + self.first_conv = Conv1d1x1(1, residual_channels) + else: + self.first_conv = Conv1d1x1(out_channels, residual_channels) + + conv_layers = [] + for layer in range(layers): + dilation = 2 ** (layer % layers_per_stack) # 1, 2, 4, 8, 16, 32 + conv = ResidualConv1dGLU( + residual_channels, gate_channels, + kernel_size=kernel_size, + skip_out_channels=skip_out_channels, + bias=True, + dropout=dropout, + dilation=dilation, + cin_channels=cin_channels, + gin_channels=gin_channels) + conv_layers.append(conv) + self.conv_layers = nn.CellList(conv_layers) + self.last_conv_layers = nn.CellList([ + nn.ReLU(), + Conv1d1x1(skip_out_channels, 
skip_out_channels), + nn.ReLU(), + Conv1d1x1(skip_out_channels, out_channels)]) + + if gin_channels > 0 and use_speaker_embedding: + assert n_speakers is not None + self.embed_speakers = Embedding( + n_speakers, gin_channels, padding_idx=None, std=0.1) + else: + self.embed_speakers = None + + if upsample_conditional_features: + self.upsample_net = getattr(upsample, upsample_net)(**upsample_params) + else: + self.upsample_net = None + + self.factor = math.sqrt(1.0 / len(self.conv_layers)) # sqrt( 1 / 24) + + def _expand_global_features(self, batch_size, time_step, g_fp, is_expand=True): + """Expand global conditioning features to all time steps + + Args: + batch_size (int): Batch size. + time_step (int): Time length. + g_fp (Tensor): Global features, (B x C) or (B x C x 1). + is_expand (bool) : Expanded global conditioning features + + Returns: + Tensor: B x C x T or B x T x C or None + """ + if g_fp is None: + return None + if len(g_fp.shape) == 2: + g_fp = self.expandim_op(g_fp, -1) + else: + g_fp = g_fp + + if is_expand: + expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step)) + return expand_fp + expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step)) + expand_fp = self.transpose_op(expand_fp, (0, 2, 1)) + return expand_fp + + def construct(self, x, cond=None, g=None, softmax=False): + """ + + Args: + x (Tensor): One-hot encoded audio signal + c (Tensor): Local conditioning feature + g (Tensor): Global conditioning feature + softmax (bool): Whether use softmax or not + + Returns: + Tensor: Net output + + """ + + g = None + B, _, T = x.shape + if g is not None: + if self.embed_speakers is not None: + g = self.embed_speakers(self.reshape_op(g, (B, -1))) + g = self.transpose_op(g, (0, 2, 1)) + g_bct = self._expand_global_features(B, T, g, is_expand=True) # None + + if cond is not None and self.upsample_net is not None: + cond = self.upsample_net(cond) # [B, 128, 10240] + + x = self.first_conv(x) + skips = 0 + for f in self.conv_layers: + x, hidden = f(x, 
cond, g_bct) # x=[B, 128, 10240], hidden=[B, 128, 10240] + skips += hidden + skips *= self.factor + + x = skips # x=[B, 128, 10240] + for f in self.last_conv_layers: + x = f(x) # x=[B, 2, 10240] + x = self.softmax(x) if softmax else x + + return x + + def relu_numpy(self, inX): + """numpy relu function""" + return np.maximum(0, inX) + + def softmax_numpy(self, x): + """ numpy softmax function """ + x -= np.max(x, axis=1, keepdims=True) + return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True) + + def incremental_forward(self, initial_input=None, c_=None, g=None, + T=100, test_inputs=None, + tqdm=lambda x: x, softmax=True, quantize=True, + log_scale_min=-50.0, is_numpy=True): + """ + Incremental forward. Current output depends on last output. + + Args: + initial_input (Tensor): Initial input, the shape is B x C x 1 + c (Tensor): Local conditioning feature, the shape is B x C x T + g (Tensor): Global conditioning feature, the shape is B x C or B x C x 1 + T (int): decoding time step. + test_inputs: Teacher forcing inputs (for debugging) + tqdm (lamda): tqmd + softmax (bool): Whether use softmax or not + quantize (bool): Whether quantize softmax output in last step when decoding current step + log_scale_min (float): Log scale minimum value + + Returns: + Tensor: Predicted on-hot encoded samples or scalar vector depending on loss type + + """ + + self.clear_buffer() + B = 1 + + if test_inputs is not None: + if self.scalar_input: + if test_inputs.shape[1] == 1: + test_inputs = self.transpose_op(test_inputs, (0, 2, 1)) + else: + if test_inputs.shape[1] == self.out_channels: + test_inputs = self.transpose_op(test_inputs, (0, 2, 1)) + + B = test_inputs.shape[0] + if T is None: + T = test_inputs.shape[1] + else: + T = max(T, test_inputs.shape[1]) + T = int(T) + + # Global conditioning + if g is not None: + if self.embed_speakers is not None: + g = self.embed_speakers(self.reshape_op(g, (B, -1))) + g = self.transpose_op(g, (0, 2, 1)) + assert g.dim() == 3 + g_btc = 
self._expand_global_features(B, T, g, is_expand=False) + + # Local conditioning + if c_ is not None: + B = c_.shape[0] + if self.upsample_net is not None: + c_ = self.upsample_net(c_) + assert c_.shape[-1] == T + if c_.shape[-1] == T: + c_ = self.transpose_op(c_, (0, 2, 1)) + + outputs = [] + if initial_input is None: + if self.scalar_input: + initial_input = self.zeros_op((B, 1, 1), mstype.float32) + else: + initial_input = np.zeros((B, 1, self.out_channels), np.float32) + initial_input[:, :, 127] = 1 + initial_input = Tensor(initial_input) + else: + if initial_input.shape[1] == self.out_channels: + initial_input = self.transpose_op(initial_input, (0, 2, 1)) + + current_input = initial_input.asnumpy() + + for t in tqdm(range(T)): + if test_inputs is not None and t < test_inputs.shape[1]: + current_input = self.expandim_op(test_inputs[:, t, :], 1) + else: + if t > 0: + current_input = outputs[-1] + + # Conditioning features for single time step + ct = None if c_ is None else self.expandim_op(c_[:, t, :], 1) + gt = None if g is None else self.expandim_op(g_btc[:, t, :], 1) + + x = current_input + ct = ct.asnumpy() + x = self.first_conv.incremental_forward(x) + + skips = 0 + for f in self.conv_layers: + x, h = f.incremental_forward(x, ct, gt) + skips += h + skips *= self.factor + x = skips + + for f in self.last_conv_layers: + try: + x = f.incremental_forward(x) + except AttributeError: + x = self.relu_numpy(x) + + # Generate next input by sampling + if self.scalar_input: + if self.output_distribution == "Logistic": + x = sample_from_discretized_mix_logistic(x.reshape((B, -1, 1)), log_scale_min=log_scale_min) + + elif self.output_distribution == "Normal": + x = sample_from_mix_gaussian(x.reshape((B, -1, 1)), log_scale_min=log_scale_min) + else: + assert False + else: + x = self.softmax_numpy(np.reshape(x, (B, -1))) if softmax else np.reshape(x, (B, -1)) + if quantize: + x = sample_from_mix_onehotcategorical(x) + + outputs += [x] + # T x B x C + outputs = 
np.stack(outputs, 0) + # B x C x T + outputs = np.transpose(outputs, (1, 2, 0)) + self.clear_buffer() + return outputs + + def clear_buffer(self): + """clear buffer""" + self.first_conv.clear_buffer() + for f in self.conv_layers: + f.clear_buffer() + for f in self.last_conv_layers: + try: + f.clear_buffer() + except AttributeError: + pass diff --git a/HuaWeiExperiment/wavenet/egs/README.md b/HuaWeiExperiment/wavenet/egs/README.md new file mode 100644 index 00000000..5b9a21b3 --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/README.md @@ -0,0 +1,3 @@ +## Recipes + +Experimental https://github.com/espnet/espnet style recipes. diff --git a/HuaWeiExperiment/wavenet/egs/gaussian/conf/backup/gaussian_wavenet.json b/HuaWeiExperiment/wavenet/egs/gaussian/conf/backup/gaussian_wavenet.json new file mode 100644 index 00000000..db3f3e4b --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/gaussian/conf/backup/gaussian_wavenet.json @@ -0,0 +1,69 @@ +{ + "name": "wavenet_vocoder", + "input_type": "raw", + "quantize_channels": 65536, + "preprocess": "preemphasis", + "postprocess": "inv_preemphasis", + "global_gain_scale": 0.55, + "sample_rate": 22050, + "silence_threshold": 2, + "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "frame_shift_ms": null, + "win_length": 1024, + "win_length_ms": -1.0, + "window": "hann", + "highpass_cutoff": 70.0, + "output_distribution": "Normal", + "log_scale_min": -16.0, + "out_channels": 2, + "layers": 24, + "stacks": 4, + "residual_channels": 128, + "gate_channels": 256, + "skip_out_channels": 128, + "dropout": 0.0, + "kernel_size": 3, + "cin_channels": 80, + "cin_pad": 2, + "upsample_conditional_features": true, + "upsample_net": "ConvInUpsampleNetwork", + "upsample_params": { + "upsample_scales": [ + 4, + 4, + 4, + 4 + ] + }, + "gin_channels": -1, + "n_speakers": 7, + "pin_memory": true, + "num_workers": 2, + "batch_size": 8, + "optimizer": "Adam", + "optimizer_params": { + "lr": 0.001, + "eps": 1e-08, + 
"weight_decay": 0.0 + }, + "lr_schedule": "step_learning_rate_decay", + "lr_schedule_kwargs": { + "anneal_rate": 0.5, + "anneal_interval": 200000 + }, + "max_train_steps": 1000000, + "nepochs": 2000, + "clip_thresh": -1, + "max_time_sec": null, + "max_time_steps": 10240, + "exponential_moving_average": true, + "ema_decay": 0.9999, + "checkpoint_interval": 100000, + "train_eval_interval": 100000, + "test_eval_epoch_interval": 50, + "save_optimizer_state": true +} diff --git a/HuaWeiExperiment/wavenet/egs/gaussian/conf/backup/gaussian_wavenet_demo.json b/HuaWeiExperiment/wavenet/egs/gaussian/conf/backup/gaussian_wavenet_demo.json new file mode 100644 index 00000000..686e065b --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/gaussian/conf/backup/gaussian_wavenet_demo.json @@ -0,0 +1,69 @@ +{ + "name": "wavenet_vocoder", + "input_type": "raw", + "quantize_channels": 65536, + "preprocess": "", + "postprocess": "", + "global_gain_scale": 1.0, + "sample_rate": 22050, + "silence_threshold": 2, + "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "frame_shift_ms": null, + "win_length": 1024, + "win_length_ms": -1.0, + "window": "hann", + "highpass_cutoff": 70.0, + "output_distribution": "Normal", + "log_scale_min": -9.0, + "out_channels": 2, + "layers": 2, + "stacks": 1, + "residual_channels": 4, + "gate_channels": 4, + "skip_out_channels": 4, + "dropout": 0.0, + "kernel_size": 3, + "cin_channels": 80, + "cin_pad": 2, + "upsample_conditional_features": true, + "upsample_net": "ConvInUpsampleNetwork", + "upsample_params": { + "upsample_scales": [ + 4, + 4, + 4, + 4 + ] + }, + "gin_channels": -1, + "n_speakers": 7, + "pin_memory": true, + "num_workers": 2, + "batch_size": 1, + "optimizer": "Adam", + "optimizer_params": { + "lr": 0.001, + "eps": 1e-08, + "weight_decay": 0.0 + }, + "lr_schedule": "step_learning_rate_decay", + "lr_schedule_kwargs": { + "anneal_rate": 0.5, + "anneal_interval": 200000 + }, + "max_train_steps": 100, + "nepochs": 
2000, + "clip_thresh": -1, + "max_time_sec": null, + "max_time_steps": 2560, + "exponential_moving_average": true, + "ema_decay": 0.9999, + "checkpoint_interval": 50, + "train_eval_interval": 50, + "test_eval_epoch_interval": 1, + "save_optimizer_state": true +} \ No newline at end of file diff --git a/HuaWeiExperiment/wavenet/egs/gaussian/conf/gaussian_wavenet.json b/HuaWeiExperiment/wavenet/egs/gaussian/conf/gaussian_wavenet.json new file mode 100644 index 00000000..b8e70649 --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/gaussian/conf/gaussian_wavenet.json @@ -0,0 +1,69 @@ +{ + "name": "wavenet_vocoder", + "input_type": "raw", + "quantize_channels": 65536, + "preprocess": "preemphasis", + "postprocess": "inv_preemphasis", + "global_gain_scale": 0.55, + "sample_rate": 22050, + "silence_threshold": 2, + "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "frame_shift_ms": null, + "win_length": 1024, + "win_length_ms": -1.0, + "window": "hann", + "highpass_cutoff": 70.0, + "output_distribution": "Normal", + "log_scale_min": -16.0, + "out_channels": 2, + "layers": 24, + "stacks": 4, + "residual_channels": 128, + "gate_channels": 256, + "skip_out_channels": 128, + "dropout": 0.0, + "kernel_size": 3, + "cin_channels": 80, + "cin_pad": 2, + "upsample_conditional_features": true, + "upsample_net": "ConvInUpsampleNetwork", + "upsample_params": { + "upsample_scales": [ + 4, + 4, + 4, + 4 + ] + }, + "gin_channels": -1, + "n_speakers": 7, + "pin_memory": true, + "num_workers": 2, + "batch_size": 8, + "optimizer": "Adam", + "optimizer_params": { + "lr": 0.001, + "eps": 1e-08, + "weight_decay": 0.0 + }, + "lr_schedule": "step_learning_rate_decay", + "lr_schedule_kwargs": { + "anneal_rate": 0.5, + "anneal_interval": 200000 + }, + "max_train_steps": 1000000, + "nepochs": 1, + "clip_thresh": -1, + "max_time_sec": null, + "max_time_steps": 10240, + "exponential_moving_average": true, + "ema_decay": 0.9999, + "checkpoint_interval": 100000, + 
"train_eval_interval": 100000, + "test_eval_epoch_interval": 50, + "save_optimizer_state": true +} diff --git a/HuaWeiExperiment/wavenet/egs/gaussian/conf/gaussian_wavenet_demo.json b/HuaWeiExperiment/wavenet/egs/gaussian/conf/gaussian_wavenet_demo.json new file mode 100644 index 00000000..408ae325 --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/gaussian/conf/gaussian_wavenet_demo.json @@ -0,0 +1,69 @@ +{ + "name": "wavenet_vocoder", + "input_type": "raw", + "quantize_channels": 65536, + "preprocess": "", + "postprocess": "", + "global_gain_scale": 1.0, + "sample_rate": 22050, + "silence_threshold": 2, + "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "frame_shift_ms": null, + "win_length": 1024, + "win_length_ms": -1.0, + "window": "hann", + "highpass_cutoff": 70.0, + "output_distribution": "Normal", + "log_scale_min": -9.0, + "out_channels": 2, + "layers": 2, + "stacks": 1, + "residual_channels": 4, + "gate_channels": 4, + "skip_out_channels": 4, + "dropout": 0.0, + "kernel_size": 3, + "cin_channels": 80, + "cin_pad": 2, + "upsample_conditional_features": true, + "upsample_net": "ConvInUpsampleNetwork", + "upsample_params": { + "upsample_scales": [ + 4, + 4, + 4, + 4 + ] + }, + "gin_channels": -1, + "n_speakers": 7, + "pin_memory": true, + "num_workers": 2, + "batch_size": 1, + "optimizer": "Adam", + "optimizer_params": { + "lr": 0.001, + "eps": 1e-08, + "weight_decay": 0.0 + }, + "lr_schedule": "step_learning_rate_decay", + "lr_schedule_kwargs": { + "anneal_rate": 0.5, + "anneal_interval": 200000 + }, + "max_train_steps": 100, + "nepochs": 20, + "clip_thresh": -1, + "max_time_sec": null, + "max_time_steps": 2560, + "exponential_moving_average": true, + "ema_decay": 0.9999, + "checkpoint_interval": 50, + "train_eval_interval": 50, + "test_eval_epoch_interval": 1, + "save_optimizer_state": true +} \ No newline at end of file diff --git a/HuaWeiExperiment/wavenet/egs/gaussian/run.sh 
b/HuaWeiExperiment/wavenet/egs/gaussian/run.sh new file mode 100644 index 00000000..118cba82 --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/gaussian/run.sh @@ -0,0 +1,124 @@ +#!/bin/bash + +script_dir=$(cd $(dirname ${BASH_SOURCE:-$0}); pwd) +VOC_DIR=$script_dir/../../ + +# Directory that contains all wav files +# **CHANGE** this to your database path +# Change this to the path of your dataset +db_root=~/data/LJSpeech-1.1/wavs/ + +spk="lj" +dumpdir=dump + +# train/dev/eval split +dev_size=10 +eval_size=10 +# Maximum size of train/dev/eval data (in hours). +# set a small value (e.g. 0.2) for testing +limit=1000000 + +# waveform global gain normalization scale +global_gain_scale=0.55 + +stage=0 +stop_stage=0 + +# Hyper parameters (.json) +# **CHANGE** here to your own hparams +hparams=conf/gaussian_wavenet_demo.json + +# Batch size at inference time. +inference_batch_size=32 +# Leave empty to use the latest checkpoint +eval_checkpoint= +# Max number of utts. for evaluation (for debugging) +eval_max_num_utt=1000000 + +# exp tag +tag="" # tag for managing experiments. + +. $VOC_DIR/utils/parse_options.sh || exit 1; + +# Set bash to 'debug' mode: it will exit on +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands' +set -e +set -u +set -o pipefail + +train_set="train_no_dev" +dev_set="dev" +eval_set="eval" +datasets=($train_set $dev_set $eval_set) + +# exp name +if [ -z ${tag} ]; then + expname=${spk}_${train_set}_$(basename ${hparams%.*}) +else + expname=${spk}_${train_set}_${tag} +fi +expdir=exp/$expname + +feat_typ="logmelspectrogram" + +# Output directories +data_root=data/$spk # train/dev/eval split data +dump_org_dir=$dumpdir/$spk/$feat_typ/org # extracted features (pair of ) +dump_norm_dir=$dumpdir/$spk/$feat_typ/norm # extracted features (pair of ) + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: train/dev/eval split" + if [ -z $db_root ]; then + echo "ERROR: DB ROOT must be specified for train/dev/eval splitting." 
+ echo " Use option --db-root \${path_contains_wav_files}" + exit 1 + fi + python $VOC_DIR/mksubset.py $db_root $data_root \ + --train-dev-test-split --dev-size $dev_size --test-size $eval_size \ + --limit=$limit +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "stage 1: Feature Generation" + for s in ${datasets[@]}; + do + python $VOC_DIR/preprocess.py wavallin $data_root/$s ${dump_org_dir}/$s \ + --hparams="global_gain_scale=${global_gain_scale}" --preset=$hparams + done + + # Compute mean-var normalization stats + find $dump_org_dir/$train_set -type f -name "*feats.npy" > train_list.txt + python $VOC_DIR/compute-meanvar-stats.py train_list.txt $dump_org_dir/meanvar.joblib + rm -f train_list.txt + + # Apply normalization + for s in ${datasets[@]}; + do + python $VOC_DIR/preprocess_normalize.py ${dump_org_dir}/$s $dump_norm_dir/$s \ + $dump_org_dir/meanvar.joblib + done + cp -f $dump_org_dir/meanvar.joblib ${dump_norm_dir}/meanvar.joblib +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "stage 2: WaveNet training" + python $VOC_DIR/train.py --dump-root $dump_norm_dir --preset $hparams \ + --checkpoint-dir=$expdir \ + --log-event-path=tensorboard/${expname} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "stage 3: Synthesis waveform from WaveNet" + if [ -z $eval_checkpoint ]; then + eval_checkpoint=$expdir/checkpoint_latest.pth + fi + name=$(basename $eval_checkpoint) + name=${name/.pth/} + for s in $dev_set $eval_set; + do + dst_dir=$expdir/generated/$name/$s + python $VOC_DIR/evaluate.py $dump_norm_dir/$s $eval_checkpoint $dst_dir \ + --preset $hparams --hparams="batch_size=$inference_batch_size" \ + --num-utterances=$eval_max_num_utt + done +fi diff --git a/HuaWeiExperiment/wavenet/egs/mol/conf/mol_wavenet.json b/HuaWeiExperiment/wavenet/egs/mol/conf/mol_wavenet.json new file mode 100644 index 00000000..ed43dbda --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/mol/conf/mol_wavenet.json @@ -0,0 +1,69 @@ 
+{ + "name": "wavenet_vocoder", + "input_type": "raw", + "quantize_channels": 65536, + "preprocess": "preemphasis", + "postprocess": "inv_preemphasis", + "global_gain_scale": 0.55, + "sample_rate": 22050, + "silence_threshold": 2, + "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "frame_shift_ms": null, + "win_length": 1024, + "win_length_ms": -1.0, + "window": "hann", + "highpass_cutoff": 70.0, + "output_distribution": "Logistic", + "log_scale_min": -16.0, + "out_channels": 30, + "layers": 24, + "stacks": 4, + "residual_channels": 128, + "gate_channels": 256, + "skip_out_channels": 128, + "dropout": 0.0, + "kernel_size": 3, + "cin_channels": 80, + "cin_pad": 2, + "upsample_conditional_features": true, + "upsample_net": "ConvInUpsampleNetwork", + "upsample_params": { + "upsample_scales": [ + 4, + 4, + 4, + 4 + ] + }, + "gin_channels": -1, + "n_speakers": 7, + "pin_memory": true, + "num_workers": 2, + "batch_size": 8, + "optimizer": "Adam", + "optimizer_params": { + "lr": 0.001, + "eps": 1e-08, + "weight_decay": 0.0 + }, + "lr_schedule": "step_learning_rate_decay", + "lr_schedule_kwargs": { + "anneal_rate": 0.5, + "anneal_interval": 200000 + }, + "max_train_steps": 1000000, + "nepochs": 2000, + "clip_thresh": -1, + "max_time_sec": null, + "max_time_steps": 10240, + "exponential_moving_average": true, + "ema_decay": 0.9999, + "checkpoint_interval": 100000, + "train_eval_interval": 100000, + "test_eval_epoch_interval": 50, + "save_optimizer_state": true +} diff --git a/HuaWeiExperiment/wavenet/egs/mol/conf/mol_wavenet_demo.json b/HuaWeiExperiment/wavenet/egs/mol/conf/mol_wavenet_demo.json new file mode 100644 index 00000000..109e615e --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/mol/conf/mol_wavenet_demo.json @@ -0,0 +1,69 @@ +{ + "name": "wavenet_vocoder", + "input_type": "raw", + "quantize_channels": 65536, + "preprocess": "", + "postprocess": "", + "global_gain_scale": 1.0, + "sample_rate": 22050, + "silence_threshold": 2, 
+ "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "frame_shift_ms": null, + "win_length": 1024, + "win_length_ms": -1.0, + "window": "hann", + "highpass_cutoff": 70.0, + "output_distribution": "Logistic", + "log_scale_min": -9.0, + "out_channels": 30, + "layers": 2, + "stacks": 1, + "residual_channels": 4, + "gate_channels": 4, + "skip_out_channels": 4, + "dropout": 0.0, + "kernel_size": 3, + "cin_channels": 80, + "cin_pad": 2, + "upsample_conditional_features": true, + "upsample_net": "ConvInUpsampleNetwork", + "upsample_params": { + "upsample_scales": [ + 4, + 4, + 4, + 4 + ] + }, + "gin_channels": -1, + "n_speakers": 7, + "pin_memory": true, + "num_workers": 2, + "batch_size": 1, + "optimizer": "Adam", + "optimizer_params": { + "lr": 0.001, + "eps": 1e-08, + "weight_decay": 0.0 + }, + "lr_schedule": "step_learning_rate_decay", + "lr_schedule_kwargs": { + "anneal_rate": 0.5, + "anneal_interval": 200000 + }, + "max_train_steps": 100, + "nepochs": 2000, + "clip_thresh": -1, + "max_time_sec": null, + "max_time_steps": 2560, + "exponential_moving_average": true, + "ema_decay": 0.9999, + "checkpoint_interval": 50, + "train_eval_interval": 50, + "test_eval_epoch_interval": 1, + "save_optimizer_state": true +} \ No newline at end of file diff --git a/HuaWeiExperiment/wavenet/egs/mol/run.sh b/HuaWeiExperiment/wavenet/egs/mol/run.sh new file mode 100644 index 00000000..6fbf8cf4 --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/mol/run.sh @@ -0,0 +1,123 @@ +#!/bin/bash + +script_dir=$(cd $(dirname ${BASH_SOURCE:-$0}); pwd) +VOC_DIR=$script_dir/../../ + +# Directory that contains all wav files +# **CHANGE** this to your database path +db_root=~/data/LJSpeech-1.1/wavs/ + +spk="lj" +dumpdir=dump + +# train/dev/eval split +dev_size=10 +eval_size=10 +# Maximum size of train/dev/eval data (in hours). +# set small value (e.g. 
0.2) for testing +limit=1000000 + +# waveform global gain normalization scale +global_gain_scale=0.55 + +stage=0 +stop_stage=0 + +# Hyper parameters (.json) +# **CHANGE** here to your own hparams +hparams=conf/mol_wavenet_demo.json + +# Batch size at inference time. +inference_batch_size=32 +# Leave empty to use latest checkpoint +eval_checkpoint= +# Max number of utts. for evaluation( for debugging) +eval_max_num_utt=1000000 + +# exp tag +tag="" # tag for managing experiments. + +. $VOC_DIR/utils/parse_options.sh || exit 1; + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +train_set="train_no_dev" +dev_set="dev" +eval_set="eval" +datasets=($train_set $dev_set $eval_set) + +# exp name +if [ -z ${tag} ]; then + expname=${spk}_${train_set}_$(basename ${hparams%.*}) +else + expname=${spk}_${train_set}_${tag} +fi +expdir=exp/$expname + +feat_typ="logmelspectrogram" + +# Output directories +data_root=data/$spk # train/dev/eval splitted data +dump_org_dir=$dumpdir/$spk/$feat_typ/org # extracted features (pair of ) +dump_norm_dir=$dumpdir/$spk/$feat_typ/norm # extracted features (pair of ) + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: train/dev/eval split" + if [ -z $db_root ]; then + echo "ERROR: DB ROOT must be specified for train/dev/eval splitting." 
+ echo " Use option --db-root \${path_contains_wav_files}" + exit 1 + fi + python $VOC_DIR/mksubset.py $db_root $data_root \ + --train-dev-test-split --dev-size $dev_size --test-size $eval_size \ + --limit=$limit +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "stage 1: Feature Generation" + for s in ${datasets[@]}; + do + python $VOC_DIR/preprocess.py wavallin $data_root/$s ${dump_org_dir}/$s \ + --hparams="global_gain_scale=${global_gain_scale}" --preset=$hparams + done + + # Compute mean-var normalization stats + find $dump_org_dir/$train_set -type f -name "*feats.npy" > train_list.txt + python $VOC_DIR/compute-meanvar-stats.py train_list.txt $dump_org_dir/meanvar.joblib + rm -f train_list.txt + + # Apply normalization + for s in ${datasets[@]}; + do + python $VOC_DIR/preprocess_normalize.py ${dump_org_dir}/$s $dump_norm_dir/$s \ + $dump_org_dir/meanvar.joblib + done + cp -f $dump_org_dir/meanvar.joblib ${dump_norm_dir}/meanvar.joblib +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "stage 2: WaveNet training" + python $VOC_DIR/train.py --dump-root $dump_norm_dir --preset $hparams \ + --checkpoint-dir=$expdir \ + --log-event-path=tensorboard/${expname} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "stage 3: Synthesis waveform from WaveNet" + if [ -z $eval_checkpoint ]; then + eval_checkpoint=$expdir/checkpoint_latest.pth + fi + name=$(basename $eval_checkpoint) + name=${name/.pth/} + for s in $dev_set $eval_set; + do + dst_dir=$expdir/generated/$name/$s + python $VOC_DIR/evaluate.py $dump_norm_dir/$s $eval_checkpoint $dst_dir \ + --preset $hparams --hparams="batch_size=$inference_batch_size" \ + --num-utterances=$eval_max_num_utt + done +fi diff --git a/HuaWeiExperiment/wavenet/egs/mulaw256/conf/mulaw256_wavenet.json b/HuaWeiExperiment/wavenet/egs/mulaw256/conf/mulaw256_wavenet.json new file mode 100644 index 00000000..db797ca9 --- /dev/null +++ 
b/HuaWeiExperiment/wavenet/egs/mulaw256/conf/mulaw256_wavenet.json @@ -0,0 +1,69 @@ +{ + "name": "wavenet_vocoder", + "input_type": "mulaw-quantize", + "quantize_channels": 256, + "preprocess": "preemphasis", + "postprocess": "inv_preemphasis", + "global_gain_scale": 0.55, + "sample_rate": 22050, + "silence_threshold": 2, + "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "frame_shift_ms": null, + "win_length": 1024, + "win_length_ms": -1.0, + "window": "hann", + "highpass_cutoff": 70.0, + "output_distribution": "Logistic", + "log_scale_min": -9.0, + "out_channels": 256, + "layers": 30, + "stacks": 3, + "residual_channels": 128, + "gate_channels": 256, + "skip_out_channels": 128, + "dropout": 0.0, + "kernel_size": 3, + "cin_channels": 80, + "cin_pad": 2, + "upsample_conditional_features": true, + "upsample_net": "ConvInUpsampleNetwork", + "upsample_params": { + "upsample_scales": [ + 4, + 4, + 4, + 4 + ] + }, + "gin_channels": -1, + "n_speakers": 7, + "pin_memory": true, + "num_workers": 2, + "batch_size": 8, + "optimizer": "Adam", + "optimizer_params": { + "lr": 0.001, + "eps": 1e-08, + "weight_decay": 0.0 + }, + "lr_schedule": "step_learning_rate_decay", + "lr_schedule_kwargs": { + "anneal_rate": 0.5, + "anneal_interval": 200000 + }, + "max_train_steps": 500000, + "nepochs": 2000, + "clip_thresh": -1, + "max_time_sec": null, + "max_time_steps": 10240, + "exponential_moving_average": true, + "ema_decay": 0.9999, + "checkpoint_interval": 100000, + "train_eval_interval": 100000, + "test_eval_epoch_interval": 50, + "save_optimizer_state": true +} \ No newline at end of file diff --git a/HuaWeiExperiment/wavenet/egs/mulaw256/conf/mulaw256_wavenet_demo.json b/HuaWeiExperiment/wavenet/egs/mulaw256/conf/mulaw256_wavenet_demo.json new file mode 100644 index 00000000..d36395d9 --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/mulaw256/conf/mulaw256_wavenet_demo.json @@ -0,0 +1,69 @@ +{ + "name": "wavenet_vocoder", + "input_type": 
"mulaw-quantize", + "quantize_channels": 256, + "preprocess": "", + "postprocess": "", + "global_gain_scale": 1.0, + "sample_rate": 22050, + "silence_threshold": 2, + "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "frame_shift_ms": null, + "win_length": 1024, + "win_length_ms": -1.0, + "window": "hann", + "highpass_cutoff": 70.0, + "output_distribution": "Logistic", + "log_scale_min": -9.0, + "out_channels": 256, + "layers": 2, + "stacks": 1, + "residual_channels": 4, + "gate_channels": 4, + "skip_out_channels": 4, + "dropout": 0.0, + "kernel_size": 3, + "cin_channels": 80, + "cin_pad": 2, + "upsample_conditional_features": true, + "upsample_net": "ConvInUpsampleNetwork", + "upsample_params": { + "upsample_scales": [ + 4, + 4, + 4, + 4 + ] + }, + "gin_channels": -1, + "n_speakers": 7, + "pin_memory": true, + "num_workers": 2, + "batch_size": 1, + "optimizer": "Adam", + "optimizer_params": { + "lr": 0.001, + "eps": 1e-08, + "weight_decay": 0.0 + }, + "lr_schedule": "step_learning_rate_decay", + "lr_schedule_kwargs": { + "anneal_rate": 0.5, + "anneal_interval": 200000 + }, + "max_train_steps": 100, + "nepochs": 2000, + "clip_thresh": -1, + "max_time_sec": null, + "max_time_steps": 2560, + "exponential_moving_average": true, + "ema_decay": 0.9999, + "checkpoint_interval": 50, + "train_eval_interval": 50, + "test_eval_epoch_interval": 1, + "save_optimizer_state": true +} \ No newline at end of file diff --git a/HuaWeiExperiment/wavenet/egs/mulaw256/run.sh b/HuaWeiExperiment/wavenet/egs/mulaw256/run.sh new file mode 100644 index 00000000..79171cd3 --- /dev/null +++ b/HuaWeiExperiment/wavenet/egs/mulaw256/run.sh @@ -0,0 +1,123 @@ +#!/bin/bash + +script_dir=$(cd $(dirname ${BASH_SOURCE:-$0}); pwd) +VOC_DIR=$script_dir/../../ + +# Directory that contains all wav files +# **CHANGE** this to your database path +db_root=~/data/LJSpeech-1.1/wavs/ + +spk="lj" +dumpdir=dump + +# train/dev/eval split +dev_size=10 +eval_size=10 +# Maximum 
size of train/dev/eval data (in hours). +# set small value (e.g. 0.2) for testing +limit=1000000 + +# waveform global gain normalization scale +global_gain_scale=0.55 + +stage=0 +stop_stage=0 + +# Hyper parameters (.json) +# **CHANGE** here to your own hparams +hparams=conf/mulaw256_wavenet_demo.json + +# Batch size at inference time. +inference_batch_size=32 +# Leave empty to use latest checkpoint +eval_checkpoint= +# Max number of utts. for evaluation( for debugging) +eval_max_num_utt=1000000 + +# exp tag +tag="" # tag for managing experiments. + +. $VOC_DIR/utils/parse_options.sh || exit 1; + +# Set bash to 'debug' mode, it will exit on : +# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands', +set -e +set -u +set -o pipefail + +train_set="train_no_dev" +dev_set="dev" +eval_set="eval" +datasets=($train_set $dev_set $eval_set) + +# exp name +if [ -z ${tag} ]; then + expname=${spk}_${train_set}_$(basename ${hparams%.*}) +else + expname=${spk}_${train_set}_${tag} +fi +expdir=exp/$expname + +feat_typ="logmelspectrogram" + +# Output directories +data_root=data/$spk # train/dev/eval splitted data +dump_org_dir=$dumpdir/$spk/$feat_typ/org # extracted features (pair of ) +dump_norm_dir=$dumpdir/$spk/$feat_typ/norm # extracted features (pair of ) + +if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then + echo "stage 0: train/dev/eval split" + if [ -z $db_root ]; then + echo "ERROR: DB ROOT must be specified for train/dev/eval splitting." 
+ echo " Use option --db-root \${path_contains_wav_files}" + exit 1 + fi + python $VOC_DIR/mksubset.py $db_root $data_root \ + --train-dev-test-split --dev-size $dev_size --test-size $eval_size \ + --limit=$limit +fi + +if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then + echo "stage 1: Feature Generation" + for s in ${datasets[@]}; + do + python $VOC_DIR/preprocess.py wavallin $data_root/$s ${dump_org_dir}/$s \ + --hparams="global_gain_scale=${global_gain_scale}" --preset=$hparams + done + + # Compute mean-var normalization stats + find $dump_org_dir/$train_set -type f -name "*feats.npy" > train_list.txt + python $VOC_DIR/compute-meanvar-stats.py train_list.txt $dump_org_dir/meanvar.joblib + rm -f train_list.txt + + # Apply normalization + for s in ${datasets[@]}; + do + python $VOC_DIR/preprocess_normalize.py ${dump_org_dir}/$s $dump_norm_dir/$s \ + $dump_org_dir/meanvar.joblib + done + cp -f $dump_org_dir/meanvar.joblib ${dump_norm_dir}/meanvar.joblib +fi + +if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then + echo "stage 2: WaveNet training" + python $VOC_DIR/train.py --dump-root $dump_norm_dir --preset $hparams \ + --checkpoint-dir=$expdir \ + --log-event-path=tensorboard/${expname} +fi + +if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then + echo "stage 3: Synthesis waveform from WaveNet" + if [ -z $eval_checkpoint ]; then + eval_checkpoint=$expdir/checkpoint_latest.pth + fi + name=$(basename $eval_checkpoint) + name=${name/.pth/} + for s in $dev_set $eval_set; + do + dst_dir=$expdir/generated/$name/$s + python $VOC_DIR/evaluate.py $dump_norm_dir/$s $eval_checkpoint $dst_dir \ + --preset $hparams --hparams="batch_size=$inference_batch_size" \ + --num-utterances=$eval_max_num_utt + done +fi diff --git a/HuaWeiExperiment/wavenet/eval.log b/HuaWeiExperiment/wavenet/eval.log new file mode 100644 index 00000000..cd0be717 --- /dev/null +++ b/HuaWeiExperiment/wavenet/eval.log @@ -0,0 +1 @@ +E:\anaconda3\envs\HuaWeiExperiment\python.exe: can't open file 
'E:\python\pythonProjects\HuaWeiExperiment\evaluate.py': [Errno 2] No such file or directory diff --git a/HuaWeiExperiment/wavenet/evaluate.py b/HuaWeiExperiment/wavenet/evaluate.py new file mode 100644 index 00000000..6b869d95 --- /dev/null +++ b/HuaWeiExperiment/wavenet/evaluate.py @@ -0,0 +1,270 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Evaluation. 
+""" +import os +from os.path import join +import argparse +import glob +import math +import audio +import numpy as np +from scipy.io import wavfile +from hparams import hparams, hparams_debug_string +from tqdm import tqdm +from mindspore import context, Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net +import mindspore.dataset.engine as de +from nnmnkwii import preprocessing as P +from nnmnkwii.datasets import FileSourceDataset +from wavenet_vocoder import WaveNet +from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_scalar_input +from src.dataset import RawAudioDataSource, MelSpecDataSource, DualDataset + +parser = argparse.ArgumentParser(description='TTS training') +parser.add_argument('--data_path', type=str, required=True, default='', + help='Directory contains preprocessed features.') +parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).') +parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path') +parser.add_argument('--is_numpy', action="store_true", default=False, help='Using numpy for inference or not') +parser.add_argument('--output_path', type=str, default='./out_wave/', help='Path to save generated audios') +parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU', 'CPU'), + help='run platform, support Ascend, GPU and CPU. 
Default: GPU') +parser.add_argument('--speaker_id', type=str, default='', + help='Use a specific speaker of the data, for multi-speaker datasets.') +args = parser.parse_args() + + +def get_data_loader(hparam, data_dir): + """ + Build the test data loader. + """ + wav_paths = glob.glob(os.path.join(data_dir, "*-wave.npy")) + if wav_paths: + X = FileSourceDataset(RawAudioDataSource(data_dir, + hop_size=audio.get_hop_size(), + max_steps=None, cin_pad=hparam.cin_pad)) + else: + X = None + C = FileSourceDataset(MelSpecDataSource(data_dir, + hop_size=audio.get_hop_size(), + max_steps=None, cin_pad=hparam.cin_pad)) + + length_x = np.array(C.file_data_source.lengths) + if C[0].shape[-1] != hparam.cin_channels: + raise RuntimeError("Invalid cin_channels {}. Expected to be {}.".format(hparam.cin_channels, C[0].shape[-1])) + + dataset = DualDataset(X, C, length_x, batch_size=hparam.batch_size, hparams=hparam) + + data_loader = de.GeneratorDataset(dataset, ["x_batch", "y_batch", "c_batch", "g_batch", "input_lengths", "mask"]) + + return data_loader, dataset + + +def batch_wavegen(hparam, net, c_input=None, g_input=None, tqdm_=None, is_numpy=True): + """ + generate audio + """ + assert c_input is not None + B = c_input.shape[0] + net.set_train(False) + + n_frames = c_input.shape[-1] + y_hat_list = [] + chunk_wise = 16 + 2 * hparam.cin_pad + for k in range(math.ceil(n_frames / chunk_wise)): + start = k * chunk_wise + end = min(n_frames, (k + 1) * chunk_wise) + lens = (end - start - hparam.cin_pad * 2) * audio.get_hop_size() + + y_hat = net.incremental_forward(c_=c_input[:, :, start: end], g=g_input, T=lens, tqdm=tqdm_, softmax=True, + quantize=True, + log_scale_min=hparam.log_scale_min, is_numpy=is_numpy) + y_hat_list.append(y_hat) + + y_hat = np.concatenate(y_hat_list, axis=2) + + if is_mulaw_quantize(hparam.input_type): + # needs to be float since mulaw_inv returns in range of [-1, 1] + y_hat = np.reshape(np.argmax(y_hat, 1), (B, -1)) + y_hat = y_hat.astype(np.float32) + for k in 
range(B): + y_hat[k] = P.inv_mulaw_quantize(y_hat[k], hparam.quantize_channels - 1) + elif is_mulaw(hparam.input_type): + y_hat = np.reshape(y_hat, (B, -1)) + for k in range(B): + y_hat[k] = P.inv_mulaw(y_hat[k], hparam.quantize_channels - 1) + else: + y_hat = np.reshape(y_hat, (B, -1)) + + if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]: + for k in range(B): + y_hat[k] = getattr(audio, hparam.postprocess)(y_hat[k]) + + if hparam.global_gain_scale > 0: + for k in range(B): + y_hat[k] /= hparam.global_gain_scale + + return y_hat + + +def to_int16(x_): + """ + convert datatype to int16 + """ + if x_.dtype == np.int16: + return x_ + assert x_.dtype == np.float32 + assert x_.min() >= -1 and x_.max() <= 1.0 + return (x_ * 32767).astype(np.int16) + + +def get_reference_file(hparam, dataset_source, idx): + """ + get reference files + """ + reference_files = [] + reference_feats = [] + for _ in range(hparam.batch_size): + if hasattr(dataset_source, "X"): + reference_files.append(dataset_source.X.collected_files[idx][0]) + else: + pass + if hasattr(dataset_source, "Mel"): + reference_feats.append(dataset_source.Mel.collected_files[idx][0]) + else: + reference_feats.append(dataset_source.collected_files[idx][0]) + idx += 1 + return reference_files, reference_feats, idx + + +def get_saved_audio_name(has_ref_file_, ref_file, ref_feat, g_fp): + """get path to save reference audio""" + if has_ref_file_: + target_audio_path = ref_file + name = os.path.splitext(os.path.basename(target_audio_path))[0].replace("-wave", "") + else: + target_feat_path = ref_feat + name = os.path.splitext(os.path.basename(target_feat_path))[0].replace("-feats", "") + # Paths + if g_fp is None: + dst_wav_path_ = join(args.output_path, "{}_gen.wav".format(name)) + target_wav_path_ = join(args.output_path, "{}_ref.wav".format(name)) + else: + dst_wav_path_ = join(args.output_path, "speaker{}_{}_gen.wav".format(g, name)) + target_wav_path_ = join(args.output_path, 
"speaker{}_{}_ref.wav".format(g, name)) + return dst_wav_path_, target_wav_path_ + + +def save_ref_audio(hparam, ref, length, target_wav_path_): + """ + save reference audio + """ + if is_mulaw_quantize(hparam.input_type): + ref = np.reshape(np.argmax(ref, 0), (-1))[:length] + ref = ref.astype(np.float32) + else: + ref = np.reshape(ref, (-1))[:length] + + if is_mulaw_quantize(hparam.input_type): + ref = P.inv_mulaw_quantize(ref, hparam.quantize_channels - 1) + elif is_mulaw(hparam.input_type): + ref = P.inv_mulaw(ref, hparam.quantize_channels - 1) + if hparam.postprocess is not None and hparam.postprocess not in ["", "none"]: + ref = getattr(audio, hparam.postprocess)(ref) + if hparam.global_gain_scale > 0: + ref /= hparam.global_gain_scale + + ref = np.clip(ref, -1.0, 1.0) + + wavfile.write(target_wav_path_, hparam.sample_rate, to_int16(ref)) + + +if __name__ == '__main__': + + if args.platform != 'Ascend': + context.set_context(mode=0, device_target=args.platform, save_graphs=False) + else: + device_id = int(os.getenv("DEVICE_ID")) + context.set_context(mode=1, device_target=args.platform, device_id=device_id) + + speaker_id = int(args.speaker_id) if args.speaker_id != '' else None + if args.preset is not None: + with open(args.preset) as f: + hparams.parse_json(f.read()) + + assert hparams.name == "wavenet_vocoder" + print(hparams_debug_string()) + + fs = hparams.sample_rate + hparams.batch_size = 10 + hparams.max_time_sec = None + hparams.max_time_steps = None + data_loaders, source_dataset = get_data_loader(hparam=hparams, data_dir=args.data_path) + + upsample_params = hparams.upsample_params + upsample_params["cin_channels"] = hparams.cin_channels + upsample_params["cin_pad"] = hparams.cin_pad + model = WaveNet( + out_channels=hparams.out_channels, + layers=hparams.layers, + stacks=hparams.stacks, + residual_channels=hparams.residual_channels, + gate_channels=hparams.gate_channels, + skip_out_channels=hparams.skip_out_channels, + 
cin_channels=hparams.cin_channels,
+        gin_channels=hparams.gin_channels,
+        n_speakers=hparams.n_speakers,
+        dropout=hparams.dropout,
+        kernel_size=hparams.kernel_size,
+        cin_pad=hparams.cin_pad,
+        upsample_conditional_features=hparams.upsample_conditional_features,
+        upsample_params=upsample_params,
+        scalar_input=is_scalar_input(hparams.input_type),
+        output_distribution=hparams.output_distribution,
+    )
+
+    param_dict = load_checkpoint(args.pretrain_ckpt)
+    load_param_into_net(model, param_dict)
+    print('Successfully loaded the pre-trained model')
+
+    os.makedirs(args.output_path, exist_ok=True)
+    cin_pad = hparams.cin_pad
+
+    file_idx = 0
+    for data in data_loaders.create_dict_iterator():
+        x, y, c, g, input_lengths = data['x_batch'], data['y_batch'], data['c_batch'], data['g_batch'], data[
+            'input_lengths']
+        if cin_pad > 0:
+            c = c.asnumpy()
+            # pad only the time axis of the (batch, channels, time) features
+            c = np.pad(c, pad_width=((0, 0), (0, 0), (cin_pad, cin_pad)), mode="edge")
+            c = Tensor(c)
+
+        ref_files, ref_feats, file_idx = get_reference_file(hparams, source_dataset, file_idx)
+        # Generate from the (possibly padded) conditioning features
+        y_hats = batch_wavegen(hparams, model, c, tqdm_=tqdm, is_numpy=args.is_numpy)
+        x = x.asnumpy()
+        input_lengths = input_lengths.asnumpy()
+        # Save each utt.
+ has_ref_file = bool(ref_files) + for i, (ref_, gen_, length_) in enumerate(zip(x, y_hats, input_lengths)): + dst_wav_path, target_wav_path = get_saved_audio_name(has_ref_file_=has_ref_file, ref_file=ref_files[i], + ref_feat=ref_feats[i], g_fp=g) + save_ref_audio(hparams, ref_, length_, target_wav_path) + + gen = gen_[:length_] + gen = np.clip(gen, -1.0, 1.0) + wavfile.write(dst_wav_path, hparams.sample_rate, to_int16(gen)) diff --git a/HuaWeiExperiment/wavenet/export.py b/HuaWeiExperiment/wavenet/export.py new file mode 100644 index 00000000..a76cef90 --- /dev/null +++ b/HuaWeiExperiment/wavenet/export.py @@ -0,0 +1,106 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Export mindir. 
+""" +import json +from os.path import join +import argparse +from warnings import warn +from hparams import hparams, hparams_debug_string +from mindspore import context, Tensor +from mindspore.train.serialization import load_checkpoint, load_param_into_net, export +from wavenet_vocoder import WaveNet +from wavenet_vocoder.util import is_mulaw_quantize, is_scalar_input +import numpy as np +from src.loss import PredictNet + +parser = argparse.ArgumentParser(description='TTS training') +parser.add_argument('--preset', type=str, default='', help='Path of preset parameters (json).') +parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_test', + help='Directory where to save model checkpoints [default: checkpoints].') +parser.add_argument('--speaker_id', type=str, default='', + help=' Use specific speaker of data in case for multi-speaker datasets.') +parser.add_argument('--pretrain_ckpt', type=str, default='', help='Pretrained checkpoint path') +parser.add_argument('--platform', type=str, default='GPU', help='Running device') +args = parser.parse_args() + +if __name__ == '__main__': + + target = args.platform + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + + speaker_id = int(args.speaker_id) if args.speaker_id != '' else None + if args.preset is not None: + with open(args.preset) as f: + hparams.parse_json(f.read()) + + assert hparams.name == "wavenet_vocoder" + print(hparams_debug_string()) + + fs = hparams.sample_rate + output_json_path = join(args.checkpoint_dir, "hparams.json") + with open(output_json_path, "w") as f: + json.dump(hparams.values(), f, indent=2) + + if is_mulaw_quantize(hparams.input_type): + if hparams.out_channels != hparams.quantize_channels: + raise RuntimeError( + "out_channels must equal to quantize_chennels if input_type is 'mulaw-quantize'") + if hparams.upsample_conditional_features and hparams.cin_channels < 0: + s = "Upsample conv layers were specified while local conditioning 
disabled. " + s += "Notice that upsample conv layers will never be used." + warn(s) + + upsample_params = hparams.upsample_params + upsample_params["cin_channels"] = hparams.cin_channels + upsample_params["cin_pad"] = hparams.cin_pad + model = WaveNet( + out_channels=hparams.out_channels, + layers=hparams.layers, + stacks=hparams.stacks, + residual_channels=hparams.residual_channels, + gate_channels=hparams.gate_channels, + skip_out_channels=hparams.skip_out_channels, + cin_channels=hparams.cin_channels, + gin_channels=hparams.gin_channels, + n_speakers=hparams.n_speakers, + dropout=hparams.dropout, + kernel_size=hparams.kernel_size, + cin_pad=hparams.cin_pad, + upsample_conditional_features=hparams.upsample_conditional_features, + upsample_params=upsample_params, + scalar_input=is_scalar_input(hparams.input_type), + output_distribution=hparams.output_distribution, + ) + + Net = PredictNet(model) + Net.set_train(False) +# if target != "Ascend": +# receptive_field = model.receptive_field +# print("Receptive field (samples / ms): {} / {}".format(receptive_field, receptive_field / fs * 1000)) + param_dict = load_checkpoint(args.pretrain_ckpt) + load_param_into_net(model, param_dict) + print('Successfully loading the pre-trained model') + + if is_mulaw_quantize(hparams.input_type): + x = np.array(np.random.random((2, 256, 10240)), dtype=np.float32) + c = np.array(np.random.random((2, 80, 44)), dtype=np.float32) + else: + x = np.array(np.random.random((2, 1, 4096)), dtype=np.float32) + c = np.array(np.random.random((2, 80, 20)), dtype=np.float32) + g = np.array([0, 0], dtype=np.int64) + + export(Net, Tensor(x), Tensor(c), Tensor(g), file_name="WaveNet", file_format='MINDIR') diff --git a/HuaWeiExperiment/wavenet/hparams.py b/HuaWeiExperiment/wavenet/hparams.py new file mode 100644 index 00000000..1b03f67b --- /dev/null +++ b/HuaWeiExperiment/wavenet/hparams.py @@ -0,0 +1,133 @@ +from wavenet_vocoder.tfcompat.hparam import HParams +import numpy as np + +# NOTE: If you 
want full control over the model architecture, please take a look
+# at the code and change whatever you want. Some hyper parameters are hardcoded.
+
+# Default hyperparameters:
+hparams = HParams(
+    name="wavenet_vocoder",
+
+    # Input type:
+    # 1. raw [-1, 1]
+    # 2. mulaw [-1, 1]
+    # 3. mulaw-quantize [0, mu]
+    # If input_type is raw or mulaw, network assumes scalar input and
+    # discretized mixture of logistic distributions output, otherwise one-hot
+    # input and softmax output are assumed.
+    # **NOTE**: if you change one of the two parameters below, you need to
+    # re-run preprocessing before training.
+    input_type="raw",
+    quantize_channels=65536,  # 65536 or 256
+
+    # Audio:
+    # time-domain pre/post-processing
+    # e.g., preemphasis/inv_preemphasis
+    # ref: LPCNet https://arxiv.org/abs/1810.11846
+    preprocess="",
+    postprocess="",
+    # waveform domain scaling
+    global_gain_scale=1.0,
+
+    sample_rate=22050,
+    # only used when mu-law is enabled
+    silence_threshold=2,
+    num_mels=80,
+    fmin=125,
+    fmax=7600,
+    fft_size=1024,
+    # shift can be specified by either hop_size or frame_shift_ms
+    hop_size=256,
+    frame_shift_ms=None,
+    win_length=1024,
+    win_length_ms=-1.0,
+    window="hann",
+
+    # DC removal
+    highpass_cutoff=70.0,
+
+    # Parametric output distribution type for scalar input
+    # 1) Logistic or 2) Normal
+    output_distribution="Logistic",
+    log_scale_min=-16.0,
+
+    # Model:
+    # This should equal `quantize_channels` if mu-law quantization is enabled,
+    # otherwise num_mixture * 3 (pi, mean, log_scale)
+    # single mixture case: 2
+    out_channels=10 * 3,
+    layers=24,
+    stacks=4,
+    residual_channels=128,
+    gate_channels=256,  # split into 2 groups internally for gated activation
+    skip_out_channels=128,
+    dropout=0.0,
+    kernel_size=3,
+
+    # Local conditioning (set a negative value to disable)
+    cin_channels=80,
+    cin_pad=2,
+    # If True, use transposed convolutions to upsample conditional features,
+    # otherwise repeat features to adjust time resolution
+
upsample_conditional_features=True, + upsample_net="ConvInUpsampleNetwork", + upsample_params={ + "upsample_scales": [4, 4, 4, 4], # should np.prod(upsample_scales) == hop_size + }, + + # Global conditioning (set negative value to disable) + # currently limited for speaker embedding + # this should only be enabled for multi-speaker dataset + gin_channels=-1, # i.e., speaker embedding dim + n_speakers=7, # 7 for CMU ARCTIC + + # Data loader + pin_memory=True, + num_workers=2, + + # Loss + + # Training: + batch_size=8, + optimizer="Adam", + optimizer_params={ + "lr": 1e-3, + "eps": 1e-8, + "weight_decay": 0.0, + }, + + # see lrschedule.py for available lr_schedule + lr_schedule="step_learning_rate_decay", + lr_schedule_kwargs={"anneal_rate": 0.5, "anneal_interval": 200000}, + + max_train_steps=1000000, + nepochs=2000, + + clip_thresh=-1, + + # max time steps can either be specified as sec or steps + # if both are None, then full audio samples are used in a batch + max_time_sec=None, + max_time_steps=10240, # 256 * 40 + + # Hold moving averaged parameters and use them for evaluation + exponential_moving_average=True, + # averaged = decay * averaged + (1 - decay) * x + ema_decay=0.9999, + + # Save + # per-step intervals + checkpoint_interval=100000, + train_eval_interval=100000, + # per-epoch interval + test_eval_epoch_interval=50, + save_optimizer_state=True, + + # Eval: +) + + +def hparams_debug_string(): + values = hparams.values() + hp = [' %s: %s' % (name, values[name]) for name in sorted(values)] + return 'Hyperparameters:\n' + '\n'.join(hp) diff --git a/HuaWeiExperiment/wavenet/mksubset.py b/HuaWeiExperiment/wavenet/mksubset.py new file mode 100644 index 00000000..62b15716 --- /dev/null +++ b/HuaWeiExperiment/wavenet/mksubset.py @@ -0,0 +1,159 @@ +# coding: utf-8 +""" +Make subset of dataset + +usage: mksubset.py [options] + +options: + -h, --help Show help message. + --limit= Limit dataset size by N-hours [default: 10000]. 
+ --train-dev-test-split Train/test split. + --dev-size= Development size or rate [default: 0.1]. + --test-size= Test size or rate [default: 0.1]. + --target-sr= Resampling. + --random-state= Random seed [default: 1234]. +""" +from docopt import docopt +import librosa +from glob import glob +from os.path import join, basename, exists, splitext +from tqdm import tqdm +import sys +import os +from shutil import copy2 +from scipy.io import wavfile +import numpy as np + + +def read_wav_or_raw(src_file, is_raw): + if is_raw: + sr = 24000 # hard coded for now + x = np.fromfile(src_file, dtype=np.int16) + else: + sr, x = wavfile.read(src_file) + return sr, x + + +def write_wav_or_raw(dst_path, sr, x, is_raw): + if is_raw: + x.tofile(dst_path) + else: + wavfile.write(dst_path, sr, x) + +if __name__ == "__main__": + args = docopt(__doc__) + in_dir = args[""] + out_dir = args[""] + limit = float(args["--limit"]) + train_dev_test_split = args["--train-dev-test-split"] + dev_size = float(args["--dev-size"]) + test_size = float(args["--test-size"]) + target_sr = args["--target-sr"] + target_sr = int(target_sr) if target_sr is not None else None + random_state = int(args["--random-state"]) + + src_files = sorted(glob(join(in_dir, "*.wav"))) + raw_files = sorted(glob(join(in_dir, "*.raw"))) + is_raw = len(src_files) == 0 and len(raw_files) > 0 + if is_raw: + print("Assuming 24kHz /16bit audio data") + src_files = raw_files + if len(src_files) == 0: + raise RuntimeError("No files found in {}".format(in_dir)) + + total_samples = 0 + indices = [] + signed_int16_max = 2**15 + + os.makedirs(out_dir, exist_ok=True) + if train_dev_test_split: + os.makedirs(join(out_dir, "train_no_dev"), exist_ok=True) + os.makedirs(join(out_dir, "dev"), exist_ok=True) + os.makedirs(join(out_dir, "eval"), exist_ok=True) + + print("Total number of utterances: {}".format(len(src_files))) + for idx, src_file in tqdm(enumerate(src_files)): + sr, x = read_wav_or_raw(src_file, is_raw) + if x.dtype == np.int16: 
+ x = x.astype(np.float32) / signed_int16_max + total_samples += len(x) + total_hours = float(total_samples) / sr / 3600.0 + indices.append(idx) + + if total_hours > limit: + print("Total hours {:.3f} exceeded limit ({} hours).".format(total_hours, limit)) + break + print("Total number of collected utterances: {}".format(len(indices))) + + if train_dev_test_split: + from sklearn.model_selection import train_test_split as split + # Get test and dev set from last + if test_size > 1 and dev_size > 1: + test_size = int(test_size) + dev_size = int(dev_size) + testdev_size = test_size + dev_size + train_indices = indices[:-testdev_size] + dev_indices = indices[-testdev_size:-testdev_size + dev_size] + test_indices = indices[-test_size:] + else: + train_indices, dev_test_indices = split( + indices, test_size=test_size + dev_size, random_state=random_state) + dev_indices, test_indices = split( + dev_test_indices, test_size=test_size / (test_size + dev_size), + random_state=random_state) + sets = [ + (sorted(train_indices), join(out_dir, "train_no_dev")), + (sorted(dev_indices), join(out_dir, "dev")), + (sorted(test_indices), join(out_dir, "eval")), + ] + else: + sets = [(indices, out_dir)] + + from sklearn.preprocessing import MinMaxScaler + scaler = MinMaxScaler() + + total_samples = {} + sr = 0 + for indices, d in sets: + set_name = basename(d) + total_samples[set_name] = 0 + for idx in tqdm(indices): + src_file = src_files[idx] + dst_path = join(d, basename(src_file)) + if target_sr is not None: + sr, x = read_wav_or_raw(src_file, is_raw) + is_int16 = x.dtype == np.int16 + if is_int16: + x = x.astype(np.float32) / signed_int16_max + if target_sr is not None and target_sr != sr: + x = librosa.resample(x, sr, target_sr) + sr = target_sr + scaler.partial_fit(x.astype(np.float64).reshape(-1, 1)) + if is_int16: + x = (x * signed_int16_max).astype(np.int16) + write_wav_or_raw(dst_path, sr, x, is_raw) + total_samples[set_name] += len(x) + else: + sr, x = 
read_wav_or_raw(src_file, is_raw) + is_int16 = x.dtype == np.int16 + if is_int16: + x = x.astype(np.float32) / signed_int16_max + scaler.partial_fit(x.astype(np.float64).reshape(-1, 1)) + total_samples[set_name] += len(x) + copy2(src_file, dst_path) + + print("Waveform min: {}".format(scaler.data_min_)) + print("Waveform max: {}".format(scaler.data_max_)) + absmax = max(np.abs(scaler.data_min_[0]), np.abs(scaler.data_max_[0])) + print("Waveform absolute max: {}".format(absmax)) + if absmax > 1.0: + print("There were clipping(s) in your dataset.") + print("Global scaling factor would be around {}".format(1.0 / absmax)) + + if train_dev_test_split: + print("Train/dev/test split:") + for n, s in zip(["train_no_dev", "dev", "eval"], sets): + hours = total_samples[n] / sr / 3600.0 + print("{}: {:.2f} hours ({} utt)".format(n, hours, len(s[0]))) + + sys.exit(0) diff --git a/HuaWeiExperiment/wavenet/preprocess.py b/HuaWeiExperiment/wavenet/preprocess.py new file mode 100644 index 00000000..08973766 --- /dev/null +++ b/HuaWeiExperiment/wavenet/preprocess.py @@ -0,0 +1,72 @@ +# coding: utf-8 +""" +Preprocess dataset + +usage: preprocess.py [options] + +options: + --num_workers= Num workers. + --hparams= Hyper parameters [default: ]. + --preset= Path of preset parameters (json). + -h, --help Show help message. 
+""" +from docopt import docopt +import os +import sys +from os.path import join +from multiprocessing import cpu_count +from tqdm import tqdm +import importlib +from hparams import hparams + + +def preprocess(mod, in_dir, out_root, num_workers): + os.makedirs(out_dir, exist_ok=True) + metadata = mod.build_from_path(in_dir, out_dir, num_workers, tqdm=tqdm) + write_metadata(metadata, out_dir) + + +def write_metadata(metadata, out_dir): + with open(os.path.join(out_dir, 'train.txt'), 'w', encoding='utf-8') as f: + for m in metadata: + f.write('|'.join([str(x) for x in m]) + '\n') + frames = sum([m[2] for m in metadata]) + sr = hparams.sample_rate + hours = frames / sr / 3600 + print('Wrote %d utterances, %d time steps (%.2f hours)' % (len(metadata), frames, hours)) + print('Min frame length: %d' % min(m[2] for m in metadata)) + print('Max frame length: %d' % max(m[2] for m in metadata)) + + +if __name__ == "__main__": + args = docopt(__doc__) + name = args[""] + in_dir = args[""] + out_dir = args[""] + num_workers = args["--num_workers"] + num_workers = cpu_count() // 2 if num_workers is None else int(num_workers) + preset = args["--preset"] + + # Load preset if specified + if preset is not None: + with open(preset) as f: + hparams.parse_json(f.read()) + # Override hyper parameters + hparams.parse(args["--hparams"]) + assert hparams.name == "wavenet_vocoder" + + print("Sampling frequency: {}".format(hparams.sample_rate)) + if name in ["cmu_arctic", "jsut", "librivox"]: + print("""warn!: {} is no longer explicitly supported! + +Please use a generic dataest 'wavallin' instead. +All you need to do is to put all wav files in a single directory.""".format(name)) + sys.exit(1) + + if name == "ljspeech": + print("""warn: ljspeech is deprecated! +Please use a generic dataset 'wavallin' instead.""") + sys.exit(1) + + mod = importlib.import_module("datasets." 
+ name) + preprocess(mod, in_dir, out_dir, num_workers) diff --git a/HuaWeiExperiment/wavenet/preprocess_normalize.py b/HuaWeiExperiment/wavenet/preprocess_normalize.py new file mode 100644 index 00000000..35c4bd80 --- /dev/null +++ b/HuaWeiExperiment/wavenet/preprocess_normalize.py @@ -0,0 +1,79 @@ +# coding: utf-8 +"""Perform meanvar normalization to preprocessed features. + +usage: preprocess_normalize.py [options] + +options: + --inverse Inverse transform. + --num_workers= Num workers. + -h, --help Show help message. +""" +from docopt import docopt +import os +from os.path import join, exists, basename, splitext +from multiprocessing import cpu_count +from tqdm import tqdm +from nnmnkwii import preprocessing as P +import numpy as np +import json +from concurrent.futures import ProcessPoolExecutor +from functools import partial +from shutil import copyfile + +import joblib +from glob import glob +from itertools import zip_longest + + +def get_paths_by_glob(in_dir, filt): + return glob(join(in_dir, filt)) + + +def _process_utterance(out_dir, audio_path, feat_path, scaler, inverse): + # [Optional] copy audio with the same name if exists + if audio_path is not None and exists(audio_path): + name = splitext(basename(audio_path))[0] + np.save(join(out_dir, name), np.load(audio_path), allow_pickle=False) + + # [Required] apply normalization for features + assert exists(feat_path) + x = np.load(feat_path) + if inverse: + y = scaler.inverse_transform(x) + else: + y = scaler.transform(x) + assert x.dtype == y.dtype + name = splitext(basename(feat_path))[0] + np.save(join(out_dir, name), y, allow_pickle=False) + + +def apply_normalization_dir2dir(in_dir, out_dir, scaler, inverse, num_workers): + # NOTE: at this point, audio_paths can be empty + audio_paths = get_paths_by_glob(in_dir, "*-wave.npy") + feature_paths = get_paths_by_glob(in_dir, "*-feats.npy") + executor = ProcessPoolExecutor(max_workers=num_workers) + futures = [] + for audio_path, feature_path in 
zip_longest(audio_paths, feature_paths): + futures.append(executor.submit( + partial(_process_utterance, out_dir, audio_path, feature_path, scaler, inverse))) + for future in tqdm(futures): + future.result() + + +if __name__ == "__main__": + args = docopt(__doc__) + in_dir = args[""] + out_dir = args[""] + scaler_path = args[""] + scaler = joblib.load(scaler_path) + inverse = args["--inverse"] + num_workers = args["--num_workers"] + num_workers = cpu_count() // 2 if num_workers is None else int(num_workers) + + os.makedirs(out_dir, exist_ok=True) + apply_normalization_dir2dir(in_dir, out_dir, scaler, inverse, num_workers) + + # Copy meta information if exists + traintxt = join(in_dir, "train.txt") + if exists(traintxt): + copyfile(join(in_dir, "train.txt"), join(out_dir, "train.txt")) diff --git a/HuaWeiExperiment/wavenet/requirements.txt b/HuaWeiExperiment/wavenet/requirements.txt new file mode 100644 index 00000000..be2c8c3e --- /dev/null +++ b/HuaWeiExperiment/wavenet/requirements.txt @@ -0,0 +1,7 @@ +numpy +audio +nnmnkwii +docopt +scipy +hparams +librosa diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0269_gen.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0269_gen.wav new file mode 100644 index 00000000..26f5746f Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0269_gen.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0269_ref.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0269_ref.wav new file mode 100644 index 00000000..e3eddd30 Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0269_ref.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0270_gen.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0270_gen.wav new file mode 100644 index 00000000..19b12cec 
Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0270_gen.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0270_ref.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0270_ref.wav new file mode 100644 index 00000000..b0c0c879 Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0270_ref.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0271_gen.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0271_gen.wav new file mode 100644 index 00000000..851d8ba0 Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0271_gen.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0271_ref.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0271_ref.wav new file mode 100644 index 00000000..1642d166 Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0271_ref.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0272_gen.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0272_gen.wav new file mode 100644 index 00000000..605e978c Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0272_gen.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0272_ref.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0272_ref.wav new file mode 100644 index 00000000..e2c9a77e Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0272_ref.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0273_gen.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 
0 0 0 0 0]_LJ050-0273_gen.wav new file mode 100644 index 00000000..bfd5fcbb Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0273_gen.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0273_ref.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0273_ref.wav new file mode 100644 index 00000000..b0d68c58 Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0273_ref.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0274_gen.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0274_gen.wav new file mode 100644 index 00000000..832c3de3 Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0274_gen.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0274_ref.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0274_ref.wav new file mode 100644 index 00000000..ebfbd07d Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0274_ref.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0275_gen.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0275_gen.wav new file mode 100644 index 00000000..8bac648c Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0275_gen.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0275_ref.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0275_ref.wav new file mode 100644 index 00000000..2a25d2e1 Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0275_ref.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 
0]_LJ050-0276_gen.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0276_gen.wav new file mode 100644 index 00000000..ac1e4e7c Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0276_gen.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0276_ref.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0276_ref.wav new file mode 100644 index 00000000..63cc15f5 Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0276_ref.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0277_gen.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0277_gen.wav new file mode 100644 index 00000000..6470ed7e Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0277_gen.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0277_ref.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0277_ref.wav new file mode 100644 index 00000000..3554fa95 Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0277_ref.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0278_gen.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0278_gen.wav new file mode 100644 index 00000000..0fd54f5e Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0278_gen.wav differ diff --git a/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0278_ref.wav b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0278_ref.wav new file mode 100644 index 00000000..90305869 Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveAudio/speaker[0 0 0 0 0 0 0 0 0 0]_LJ050-0278_ref.wav 
differ diff --git a/HuaWeiExperiment/wavenet/saveCheckpoint/hparams.json b/HuaWeiExperiment/wavenet/saveCheckpoint/hparams.json new file mode 100644 index 00000000..a2ddabce --- /dev/null +++ b/HuaWeiExperiment/wavenet/saveCheckpoint/hparams.json @@ -0,0 +1,69 @@ +{ + "name": "wavenet_vocoder", + "input_type": "raw", + "quantize_channels": 65536, + "preprocess": "preemphasis", + "postprocess": "inv_preemphasis", + "global_gain_scale": 0.55, + "sample_rate": 22050, + "silence_threshold": 2, + "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "frame_shift_ms": null, + "win_length": 1024, + "win_length_ms": -1.0, + "window": "hann", + "highpass_cutoff": 70.0, + "output_distribution": "Normal", + "log_scale_min": -16.0, + "out_channels": 2, + "layers": 24, + "stacks": 4, + "residual_channels": 128, + "gate_channels": 256, + "skip_out_channels": 128, + "dropout": 0.0, + "kernel_size": 3, + "cin_channels": 80, + "cin_pad": 2, + "upsample_conditional_features": true, + "upsample_net": "ConvInUpsampleNetwork", + "upsample_params": { + "upsample_scales": [ + 4, + 4, + 4, + 4 + ] + }, + "gin_channels": -1, + "n_speakers": 7, + "pin_memory": true, + "num_workers": 2, + "batch_size": 8, + "optimizer": "Adam", + "optimizer_params": { + "lr": 0.001, + "eps": 1e-08, + "weight_decay": 0.0 + }, + "lr_schedule": "step_learning_rate_decay", + "lr_schedule_kwargs": { + "anneal_rate": 0.5, + "anneal_interval": 200000 + }, + "max_train_steps": 1000000, + "nepochs": 1, + "clip_thresh": -1, + "max_time_sec": null, + "max_time_steps": 10240, + "exponential_moving_average": true, + "ema_decay": 0.9999, + "checkpoint_interval": 100000, + "train_eval_interval": 100000, + "test_eval_epoch_interval": 50, + "save_optimizer_state": true +} \ No newline at end of file diff --git a/HuaWeiExperiment/wavenet/saveCheckpoint/wavenet-1_1635.ckpt b/HuaWeiExperiment/wavenet/saveCheckpoint/wavenet-1_1635.ckpt new file mode 100644 index 00000000..b2c02a3e Binary files 
/dev/null and b/HuaWeiExperiment/wavenet/saveCheckpoint/wavenet-1_1635.ckpt differ diff --git a/HuaWeiExperiment/wavenet/saveCheckpoint/wavenet-graph.meta b/HuaWeiExperiment/wavenet/saveCheckpoint/wavenet-graph.meta new file mode 100644 index 00000000..647124fc Binary files /dev/null and b/HuaWeiExperiment/wavenet/saveCheckpoint/wavenet-graph.meta differ diff --git a/HuaWeiExperiment/wavenet/saveConvert/hparams.json b/HuaWeiExperiment/wavenet/saveConvert/hparams.json new file mode 100644 index 00000000..a2ddabce --- /dev/null +++ b/HuaWeiExperiment/wavenet/saveConvert/hparams.json @@ -0,0 +1,69 @@ +{ + "name": "wavenet_vocoder", + "input_type": "raw", + "quantize_channels": 65536, + "preprocess": "preemphasis", + "postprocess": "inv_preemphasis", + "global_gain_scale": 0.55, + "sample_rate": 22050, + "silence_threshold": 2, + "num_mels": 80, + "fmin": 125, + "fmax": 7600, + "fft_size": 1024, + "hop_size": 256, + "frame_shift_ms": null, + "win_length": 1024, + "win_length_ms": -1.0, + "window": "hann", + "highpass_cutoff": 70.0, + "output_distribution": "Normal", + "log_scale_min": -16.0, + "out_channels": 2, + "layers": 24, + "stacks": 4, + "residual_channels": 128, + "gate_channels": 256, + "skip_out_channels": 128, + "dropout": 0.0, + "kernel_size": 3, + "cin_channels": 80, + "cin_pad": 2, + "upsample_conditional_features": true, + "upsample_net": "ConvInUpsampleNetwork", + "upsample_params": { + "upsample_scales": [ + 4, + 4, + 4, + 4 + ] + }, + "gin_channels": -1, + "n_speakers": 7, + "pin_memory": true, + "num_workers": 2, + "batch_size": 8, + "optimizer": "Adam", + "optimizer_params": { + "lr": 0.001, + "eps": 1e-08, + "weight_decay": 0.0 + }, + "lr_schedule": "step_learning_rate_decay", + "lr_schedule_kwargs": { + "anneal_rate": 0.5, + "anneal_interval": 200000 + }, + "max_train_steps": 1000000, + "nepochs": 1, + "clip_thresh": -1, + "max_time_sec": null, + "max_time_steps": 10240, + "exponential_moving_average": true, + "ema_decay": 0.9999, + 
"checkpoint_interval": 100000, + "train_eval_interval": 100000, + "test_eval_epoch_interval": 50, + "save_optimizer_state": true +} \ No newline at end of file diff --git a/HuaWeiExperiment/wavenet/scripts/eval.log b/HuaWeiExperiment/wavenet/scripts/eval.log new file mode 100644 index 00000000..a33aa47a --- /dev/null +++ b/HuaWeiExperiment/wavenet/scripts/eval.log @@ -0,0 +1 @@ +E:\anaconda3\envs\HuaWeiExperiment\python.exe: can't open file 'E:\python\pythonProjects\HuaWeiExperiment\wavenet\scripts\evaluate.py': [Errno 2] No such file or directory diff --git a/HuaWeiExperiment/wavenet/scripts/run_distribute_train_ascend.sh b/HuaWeiExperiment/wavenet/scripts/run_distribute_train_ascend.sh new file mode 100644 index 00000000..c284d492 --- /dev/null +++ b/HuaWeiExperiment/wavenet/scripts/run_distribute_train_ascend.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +ROOT_PATH=$(pwd) +train_file=$1 +DATA_DIR=$2 +PRESET=$3 +CKPT_DIR=$4 +export RANK_TABLE_FILE=$5 +export HCCL_CONNECT_TIMEOUT=600 +export RANK_SIZE=$6 +begin=$7 +for((i=begin;i<begin+RANK_SIZE;i++)) +do + export RANK_ID=$((i-begin)) + export DEVICE_ID=$i + python3 ${ROOT_PATH}/${train_file} --data_path $DATA_DIR --preset $PRESET \ +--platform=Ascend --is_distributed --checkpoint_dir $CKPT_DIR >log$i.log 2>&1 & +done + diff --git a/HuaWeiExperiment/wavenet/scripts/run_distribute_train_gpu.sh b/HuaWeiExperiment/wavenet/scripts/run_distribute_train_gpu.sh new file mode 100644 index 00000000..f0ac0689 --- /dev/null +++ b/HuaWeiExperiment/wavenet/scripts/run_distribute_train_gpu.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DATA_PATH=$1 +PRESET=$2 +CHECKPOINT_DIR=$3 +mpirun --allow-run-as-root -n 8 --output-filename log_output --merge-stderr-to-stdout \ + python ./train.py --data_path=$DATA_PATH --preset=$PRESET --checkpoint_dir=$CHECKPOINT_DIR --is_distributed > train.log 2>&1 & diff --git a/HuaWeiExperiment/wavenet/scripts/run_eval_ascend.sh b/HuaWeiExperiment/wavenet/scripts/run_eval_ascend.sh new file mode 100644 index 00000000..82943609 --- /dev/null +++ b/HuaWeiExperiment/wavenet/scripts/run_eval_ascend.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +ROOT_PATH=$(pwd) +export DEVICE_ID=$1 +if [ $# == 7 ] +then + python3 ${ROOT_PATH}/$2 --data_path $3 --preset $4 \ +--platform=Ascend --pretrain_ckpt $5 --is_numpy --output_path $7 >log_eval.log 2>&1 & +else + python3 ${ROOT_PATH}/$2 --data_path $3 --preset $4 \ +--platform=Ascend --pretrain_ckpt $5 --output_path $6 >log_eval.log 2>&1 & +fi \ No newline at end of file diff --git a/HuaWeiExperiment/wavenet/scripts/run_eval_cpu.sh b/HuaWeiExperiment/wavenet/scripts/run_eval_cpu.sh new file mode 100644 index 00000000..5c43192a --- /dev/null +++ b/HuaWeiExperiment/wavenet/scripts/run_eval_cpu.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +if [ $# == 5 ] +then + python ./evaluate.py --data_path=$1 --preset=$2 --pretrain_ckpt=$3 --is_numpy --output_path=$5 --platform=CPU \ + > eval.log 2>&1 & +else + python ./evaluate.py --data_path=$1 --preset=$2 --pretrain_ckpt=$3 --output_path=$4 --platform=CPU > eval.log 2>&1 & +fi diff --git a/HuaWeiExperiment/wavenet/scripts/run_eval_gpu.sh b/HuaWeiExperiment/wavenet/scripts/run_eval_gpu.sh new file mode 100644 index 00000000..42ea8f0c --- /dev/null +++ b/HuaWeiExperiment/wavenet/scripts/run_eval_gpu.sh @@ -0,0 +1,23 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +if [ $# == 6 ] +then + CUDA_VISIBLE_DEVICES=$1 python ../evaluate.py --data_path=$2 --preset=$3 --pretrain_ckpt=$4 \ + --is_numpy --output_path=$6 > eval.log 2>&1 & +else + CUDA_VISIBLE_DEVICES=$1 python ../evaluate.py --data_path=$2 --preset=$3 --pretrain_ckpt=$4 \ + --output_path=$5 > eval.log 2>&1 & +fi diff --git a/HuaWeiExperiment/wavenet/scripts/run_standalone_train_ascend.sh b/HuaWeiExperiment/wavenet/scripts/run_standalone_train_ascend.sh new file mode 100644 index 00000000..311bf54a --- /dev/null +++ b/HuaWeiExperiment/wavenet/scripts/run_standalone_train_ascend.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +ROOT_PATH=$(pwd) +train_file=$1 +DATA_DIR=$2 +PRESET=$3 +CKPT_DIR=$4 +export DEVICE_ID=$5 +export RANK_ID=0 +export RANK_SIZE=1 +python3 ${ROOT_PATH}/${train_file} --data_path $DATA_DIR --preset $PRESET \ +--platform=Ascend --checkpoint_dir $CKPT_DIR >log_train.log 2>&1 & diff --git a/HuaWeiExperiment/wavenet/scripts/run_standalone_train_cpu.sh b/HuaWeiExperiment/wavenet/scripts/run_standalone_train_cpu.sh new file mode 100644 index 00000000..948952ba --- /dev/null +++ b/HuaWeiExperiment/wavenet/scripts/run_standalone_train_cpu.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DATA_PATH=$1 +PRESET=$2 +CHECKPOINT_DIR=$3 +python ./train.py --data_path=$DATA_PATH --preset=$PRESET --checkpoint_dir=$CHECKPOINT_DIR \ +--platform=CPU > train.log 2>&1 & diff --git a/HuaWeiExperiment/wavenet/scripts/run_standalone_train_gpu.sh b/HuaWeiExperiment/wavenet/scripts/run_standalone_train_gpu.sh new file mode 100644 index 00000000..b9aae15a --- /dev/null +++ b/HuaWeiExperiment/wavenet/scripts/run_standalone_train_gpu.sh @@ -0,0 +1,21 @@ +#!/bin/bash +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +DEVICE_ID=$1 +DATA_PATH=$2 +PRESET=$3 +CHECKPOINT_DIR=$4 +CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./train.py --data_path=$DATA_PATH --preset=$PRESET \ +--checkpoint_dir=$CHECKPOINT_DIR > train.log 2>&1 & diff --git a/HuaWeiExperiment/wavenet/scripts/train.log b/HuaWeiExperiment/wavenet/scripts/train.log new file mode 100644 index 00000000..521f7e8f --- /dev/null +++ b/HuaWeiExperiment/wavenet/scripts/train.log @@ -0,0 +1 @@ +E:\anaconda3\envs\HuaWeiExperiment\python.exe: can't open file 'E:\python\pythonProjects\HuaWeiExperiment\wavenet\scripts\train.py': [Errno 2] No such file or directory diff --git a/HuaWeiExperiment/wavenet/src/__init__.py b/HuaWeiExperiment/wavenet/src/__init__.py new file mode 100644 index 00000000..1e5f7fbe --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================ diff --git a/HuaWeiExperiment/wavenet/src/audio.py b/HuaWeiExperiment/wavenet/src/audio.py new file mode 100644 index 00000000..663516f2 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/audio.py @@ -0,0 +1,173 @@ +import librosa +import librosa.filters +import numpy as np +from hparams import hparams +from scipy.io import wavfile +from nnmnkwii import preprocessing as P + + +def low_cut_filter(x, fs, cutoff=70): + """APPLY LOW CUT FILTER. + + https://github.com/kan-bayashi/PytorchWaveNetVocoder + + Args: + x (ndarray): Waveform sequence. + fs (int): Sampling frequency. + cutoff (float): Cutoff frequency of low cut filter. + Return: + ndarray: Low cut filtered waveform sequence. + """ + nyquist = fs // 2 + norm_cutoff = cutoff / nyquist + from scipy.signal import firwin, lfilter + + # low cut filter + fil = firwin(255, norm_cutoff, pass_zero=False) + lcf_x = lfilter(fil, 1, x) + + return lcf_x + + +def load_wav(path): + sr, x = wavfile.read(path) + signed_int16_max = 2**15 + if x.dtype == np.int16: + x = x.astype(np.float32) / signed_int16_max + if sr != hparams.sample_rate: + x = librosa.resample(x, sr, hparams.sample_rate) + x = np.clip(x, -1.0, 1.0) + return x + + +def save_wav(wav, path): + wav *= 32767 / max(0.01, np.max(np.abs(wav))) + wavfile.write(path, hparams.sample_rate, wav.astype(np.int16)) + + +def trim(quantized): + start, end = start_and_end_indices(quantized, hparams.silence_threshold) + return quantized[start:end] + + +def preemphasis(x, coef=0.85): + return P.preemphasis(x, coef) + + +def inv_preemphasis(x, coef=0.85): + return P.inv_preemphasis(x, coef) + + +def adjust_time_resolution(quantized, mel): + """Adjust time resolution by repeating features + + Args: + quantized (ndarray): (T,) + mel (ndarray): (N, D) + + Returns: + tuple: Tuple of (T,) and (T, D) + """ + assert len(quantized.shape) == 1 + assert len(mel.shape) == 2 + + upsample_factor = quantized.size // 
mel.shape[0] + mel = np.repeat(mel, upsample_factor, axis=0) + n_pad = quantized.size - mel.shape[0] + if n_pad != 0: + assert n_pad > 0 + mel = np.pad(mel, [(0, n_pad), (0, 0)], mode="constant", constant_values=0) + + # trim + start, end = start_and_end_indices(quantized, hparams.silence_threshold) + + return quantized[start:end], mel[start:end, :] + + +def start_and_end_indices(quantized, silence_threshold=2): + for start in range(quantized.size): + if abs(quantized[start] - 127) > silence_threshold: + break + for end in range(quantized.size - 1, 1, -1): + if abs(quantized[end] - 127) > silence_threshold: + break + + assert abs(quantized[start] - 127) > silence_threshold + assert abs(quantized[end] - 127) > silence_threshold + + return start, end + + +def logmelspectrogram(y, pad_mode="reflect"): + """Same log-melspectrogram computation as espnet + https://github.com/espnet/espnet + from espnet.transform.spectrogram import logmelspectrogram + """ + D = _stft(y, pad_mode=pad_mode) + S = _linear_to_mel(np.abs(D)) + S = np.log10(np.maximum(S, 1e-10)) + return S + + +def get_hop_size(): + hop_size = hparams.hop_size + if hop_size is None: + assert hparams.frame_shift_ms is not None + hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate) + return hop_size + + +def get_win_length(): + win_length = hparams.win_length + if win_length < 0: + assert hparams.win_length_ms > 0 + win_length = int(hparams.win_length_ms / 1000 * hparams.sample_rate) + return win_length + + +def _stft(y, pad_mode="constant"): + # use constant padding (defaults to zeros) instead of reflection padding + return librosa.stft(y=y, n_fft=hparams.fft_size, hop_length=get_hop_size(), + win_length=get_win_length(), window=hparams.window, + pad_mode=pad_mode) + + +def pad_lr(x, fsize, fshift): + return (0, fsize) + +# Conversions: + + +_mel_basis = None + + +def _linear_to_mel(spectrogram): + global _mel_basis + if _mel_basis is None: + _mel_basis = _build_mel_basis() + return 
np.dot(_mel_basis, spectrogram) + + +def _build_mel_basis(): + if hparams.fmax is not None: + assert hparams.fmax <= hparams.sample_rate // 2 + return librosa.filters.mel(hparams.sample_rate, hparams.fft_size, + fmin=hparams.fmin, fmax=hparams.fmax, + n_mels=hparams.num_mels) + + +def _amp_to_db(x): + min_level = np.exp(hparams.min_level_db / 20 * np.log(10)) + return 20 * np.log10(np.maximum(min_level, x)) + + +def _db_to_amp(x): + return np.power(10.0, x * 0.05) + + +def _normalize(S): + return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1) + + +def _denormalize(S): + return (np.clip(S, 0, 1) * -hparams.min_level_db) + hparams.min_level_db diff --git a/HuaWeiExperiment/wavenet/src/callback.py b/HuaWeiExperiment/wavenet/src/callback.py new file mode 100644 index 00000000..c2f62dee --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/callback.py @@ -0,0 +1,103 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Defined callback for WaveNet. +""" +import time +from mindspore.train.callback import Callback +from mindspore import Tensor +import numpy as np + + +class TimeMonitor(Callback): + """ + Time monitor for calculating cost of each epoch. + + Args: + data_size (int): step size of an epoch. 
+ """ + + def __init__(self, data_size): + super(TimeMonitor, self).__init__() + self.data_size = data_size + + def epoch_begin(self, run_context): + self.epoch_time = time.time() + + def epoch_end(self, run_context): + epoch_mseconds = (time.time() - self.epoch_time) * 1000 + per_step_mseconds = epoch_mseconds / self.data_size + print("epoch time: {0}, per step time: {1}".format(epoch_mseconds, per_step_mseconds), flush=True) + + def step_begin(self, run_context): + self.step_time = time.time() + + def step_end(self, run_context): + step_mseconds = (time.time() - self.step_time) * 1000 + print(f"step time {step_mseconds}", flush=True) + + +class Monitor(Callback): + """ + Monitor loss and time. + + Args: + lr_init (numpy array): train lr + + Returns: + None + + Examples: + >>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy()) + """ + + def __init__(self, lr_init=None): + super(Monitor, self).__init__() + self.lr_init = lr_init + self.lr_init_len = len(lr_init) + + def epoch_begin(self, run_context): + self.losses = [] + self.epoch_time = time.time() + + def epoch_end(self, run_context): + cb_params = run_context.original_args() + + epoch_mseconds = (time.time() - self.epoch_time) + per_step_mseconds = epoch_mseconds / cb_params.batch_num + print("epoch time: {:5.3f}, per step time: {:5.3f}, avg loss: {:5.6f}".format(epoch_mseconds, + per_step_mseconds, + np.mean(self.losses))) + + def step_begin(self, run_context): + self.step_time = time.time() + + def step_end(self, run_context): + """step end""" + cb_params = run_context.original_args() + step_mseconds = (time.time() - self.step_time) + step_loss = cb_params.net_outputs + + if isinstance(step_loss, (tuple, list)) and isinstance(step_loss[0], Tensor): + step_loss = step_loss[0] + if isinstance(step_loss, Tensor): + step_loss = np.mean(step_loss.asnumpy()) + + self.losses.append(step_loss) + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + + print("epoch: [{:3d}/{:3d}], step:[{:5d}/{:5d}], 
loss:[{:5.6f}/{:5.6f}], time:[{:5.3f}], lr:[{:.9f}]".format( + cb_params.cur_epoch_num - + 1, cb_params.epoch_num, cur_step_in_epoch, cb_params.batch_num, step_loss, + np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1])) diff --git a/HuaWeiExperiment/wavenet/src/dataset.py b/HuaWeiExperiment/wavenet/src/dataset.py new file mode 100644 index 00000000..4b4a8f5e --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/dataset.py @@ -0,0 +1,267 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Create train dataset.
+""" +import os +import math +import numpy as np +import audio +from nnmnkwii.datasets import FileSourceDataset +from nnmnkwii import preprocessing as P +from wavenet_vocoder.util import is_mulaw_quantize +from train_pytorch import _pad, _pad_2d, to_categorical, ensure_divisible, RawAudioDataSource, MelSpecDataSource, \ + assert_ready_for_upsampling +import mindspore.dataset.engine as de + + +def sequence_mask(sequence_length, max_len=None): + """make sequence mask for loss""" + if max_len is None: + max_len = np.max(sequence_length) + batch_size = len(sequence_length) + seq_range = np.linspace(0, max_len - 1, max_len, dtype=np.int32) + seq_range_expand = np.tile(np.expand_dims(seq_range, 0), (batch_size, 1)) + seq_length_expand = np.tile(np.expand_dims(sequence_length, 1), (1, max_len)) + seq_length_expand = np.expand_dims(np.array(seq_range_expand < seq_length_expand, dtype=np.float32), -1) + return seq_length_expand + + +class DistributedSampler(): + """function to distribute and shuffle sample""" + + def __init__(self, dataset, rank, group_size, shuffle=True, seed=0): + self.dataset = dataset + self.rank = rank + self.group_size = group_size + self.dataset_len = len(self.dataset) # num steps per epoch = 1635 + self.num_samplers = int(math.ceil(self.dataset_len * 1.0 / self.group_size)) + self.total_size = self.num_samplers * self.group_size + self.shuffle = shuffle + self.seed = seed + + def __iter__(self): + if self.shuffle: + self.seed = (self.seed + 1) & 0xffffffff + np.random.seed(self.seed) + indices = np.random.permutation(self.dataset_len).tolist() + else: + indices = list(range(self.dataset_len)) + + indices += indices[:(self.total_size - len(indices))] + indices = indices[self.rank::self.group_size] + return iter(indices) + + def __len__(self): + return self.num_samplers + + +def process_condition_batch(max_time_steps, hparams, batch): + """process condition batch""" + cin_pad = hparams.cin_pad + new_batch = [] + for batch_ in batch: + x, c, g = batch_ 
+ + if hparams.upsample_conditional_features: + assert_ready_for_upsampling(x, c, cin_pad=0) + if max_time_steps is not None: + max_steps = ensure_divisible(max_time_steps, audio.get_hop_size(), True) + + if len(x) > max_steps: + max_time_frames = max_steps // audio.get_hop_size() + s = np.random.randint(cin_pad, len(c) - max_time_frames - cin_pad) + ts = s * audio.get_hop_size() + x = x[ts:ts + audio.get_hop_size() * max_time_frames] + c = c[s - cin_pad:s + max_time_frames + cin_pad, :] + assert_ready_for_upsampling(x, c, cin_pad=cin_pad) + else: + x, c = audio.adjust_time_resolution(x, c) + if max_time_steps is not None and len(x) > max_time_steps: + s = np.random.randint(cin_pad, len(x) - max_time_steps - cin_pad) + x = x[s:s + max_time_steps] + c = c[s - cin_pad:s + max_time_steps + cin_pad, :] + assert len(x) == len(c) + new_batch.append((x, c, g)) + return new_batch + + +def process_no_condition_batch(max_time_steps, batch): + """process no condition batch""" + new_batch = [] + for batch_ in batch: + x, c, g = batch_ + x = audio.trim(x) + if max_time_steps is not None and len(x) > max_time_steps: + s = np.random.randint(0, len(x) - max_time_steps) + x = x[s:s + max_time_steps] + new_batch.append((x, c, g)) + return new_batch + + +def collate_fn(batch, hparams): + """ + Create batch + """ + + local_conditioning = len(batch[0]) >= 2 and hparams.cin_channels > 0 + global_conditioning = len(batch[0]) >= 3 and hparams.gin_channels > 0 + + if hparams.max_time_sec is not None: + max_time_steps = int(hparams.max_time_sec * hparams.sample_rate) + elif hparams.max_time_steps is not None: + max_time_steps = hparams.max_time_steps + else: + max_time_steps = None + + if local_conditioning: + new_batch = process_condition_batch(max_time_steps, hparams, batch) + else: + new_batch = process_no_condition_batch(max_time_steps, batch) + batch = new_batch + + input_lengths = [len(x[0]) for x in batch] + max_input_len = max(input_lengths) + # (B, T, C) + # pad for time-axis + if 
is_mulaw_quantize(hparams.input_type): + padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1) + x_batch = np.array( + [_pad_2d(to_categorical(x[0], num_classes=hparams.quantize_channels), max_input_len, 0, padding_value) for x + in batch], dtype=np.float32) + else: + x_batch = np.array([_pad_2d(x[0].reshape(-1, 1), max_input_len) + for x in batch], dtype=np.float32) # pad zero to 2d wave with max input length + + assert len(x_batch.shape) == 3 + + # (B, T) + if is_mulaw_quantize(hparams.input_type): + padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1) + y_batch = np.array([_pad(x[0], max_input_len, constant_values=padding_value) + for x in batch], dtype=np.int32) + else: + y_batch = np.array([_pad(x[0], max_input_len) for x in batch], + dtype=np.float32) # pad zero to 1d wave with max input length + assert len(y_batch.shape) == 2 + + # (B, T, D) + if local_conditioning: + max_len = max([len(x[1]) for x in batch]) + c_batch = np.array([_pad_2d(x[1], max_len) for x in batch], + dtype=np.float32) # pad zero to 2d logmel with max mel length + assert len(c_batch.shape) == 3 + # (B x C x T) + c_batch = c_batch.transpose((0, 2, 1)) + else: + c_batch = np.zeros(hparams.batch_size, dtype=np.float32) + + if global_conditioning: + g_batch = [x[2] for x in batch] + else: + # g_batch = None # MindSpore does not support None type + g_batch = np.zeros(hparams.batch_size, dtype=np.int64) + + # Convert to channel first i.e., (B, C, T) + x_batch = x_batch.transpose((0, 2, 1)) + # Add extra axis (same expansion for mu-law quantized and raw input) + y_batch = np.expand_dims(y_batch, axis=-1) + + mask = sequence_mask(input_lengths, max_len=x_batch.shape[-1]) + + return x_batch, y_batch, c_batch, g_batch, input_lengths, mask + + +class DualDataset(): + """Create Dataset loader for audio Mel and Audio""" + + def __init__(self, X, Mel, length, batch_size, hparams): 
self.multi_speaker = X.file_data_source.multi_speaker + self.X = X + self.Mel = Mel + self.length = length + self.hparams = hparams + + self.sorted_index = list(np.argsort(length)) + self.bins = [self.sorted_index[i:i + batch_size] for i in range(0, len(self.sorted_index), batch_size)] + self.size = len(self.bins) # num_steps_per_epoch + + def __getitem__(self, idx): + if self.multi_speaker: + speaker_id = self.X.file_data_source.speaker_ids[idx] + else: + speaker_id = None + + combined_data = [] + mel_len, audio_len = [], [] + for i in self.bins[idx]: # one batch data + + if self.Mel is not None: + mel = self.Mel[i] + raw_audio = self.X[i] + length_mel, length_x = mel.shape[0], raw_audio.shape[0] + combined_data.append((raw_audio, mel, speaker_id)) + mel_len.append(length_mel) + audio_len.append(length_x) + else: + raw_audio = self.X[i] + length_x = raw_audio.shape[0] + combined_data.append((raw_audio, speaker_id)) + audio_len.append(length_x) + + x_batch, y_batch, c_batch, g_batch, input_lengths, mask = collate_fn(combined_data, self.hparams) + + return x_batch, y_batch, c_batch, g_batch, input_lengths, mask + + def __len__(self): + return self.size + + +def get_data_loaders(dump_root, speaker_id, hparams=None, rank_id=None, group_size=None): + """create train dataset""" + local_conditioning = hparams.cin_channels > 0 + + if hparams.max_time_steps is not None: + max_steps = ensure_divisible(hparams.max_time_steps, audio.get_hop_size(), True) + else: + max_steps = None + + X = FileSourceDataset( + RawAudioDataSource(os.path.join(dump_root, 'train_no_dev'), speaker_id=speaker_id, + max_steps=max_steps, cin_pad=hparams.cin_pad, + hop_size=audio.get_hop_size())) + + if local_conditioning: + Mel = FileSourceDataset( + MelSpecDataSource(os.path.join(dump_root, 'train_no_dev'), speaker_id=speaker_id, + max_steps=max_steps, cin_pad=hparams.cin_pad, + hop_size=audio.get_hop_size())) + assert len(X) == len(Mel) + print("Local conditioning enabled. 
Shape of a sample: {}.".format(Mel[0].shape)) + else: + Mel = None + print("length of the dataset is {}".format(len(X))) + length_x = np.array(X.file_data_source.lengths) + + dataset = DualDataset(X, Mel, length_x, batch_size=hparams.batch_size, hparams=hparams) + sampler = DistributedSampler(dataset, rank_id, group_size, shuffle=True, seed=0) + data_loaders = de.GeneratorDataset(dataset, ["x_batch", "y_batch", "c_batch", "g_batch", "input_lengths", "mask"], + sampler=sampler) + + return data_loaders diff --git a/HuaWeiExperiment/wavenet/src/loss.py b/HuaWeiExperiment/wavenet/src/loss.py new file mode 100644 index 00000000..612a6b82 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/loss.py @@ -0,0 +1,447 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Loss function definition. 
+""" +import os +import numpy as np +import matplotlib +import matplotlib.pyplot as plt +import mindspore as ms +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore import nn, ops, Tensor, Parameter, context +from mindspore.context import ParallelMode +from mindspore.communication.management import get_group_size +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +from mindspore.parallel._utils import _get_gradients_mean + +from nnmnkwii import preprocessing as P1 +from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw +from wavenet_vocoder.mixture import discretized_mix_logistic_loss +from wavenet_vocoder.mixture import mix_gaussian_loss +from train_pytorch import to_categorical +from tqdm import tqdm +import audio + +matplotlib.use('Agg') + + +def sequence_mask(sequence_length, max_len=None): + """make sequence mask""" + sequence_length = sequence_length.asnumpy() + if max_len is None: + max_len = np.max(sequence_length) + batch_size = sequence_length.shape[0] + seq_range = np.linspace(0, max_len - 1, max_len, dtype=np.int32) + seq_range_expand = np.tile(np.expand_dims(seq_range, 0), (batch_size, 1)) + seq_length_expand = np.tile(np.expand_dims(sequence_length, 1), (1, max_len)) + seq_length_expand = np.expand_dims(np.array(seq_range_expand < seq_length_expand, dtype=np.float32), -1) + return Tensor(seq_length_expand) + + +class MaskedCrossEntropyLoss(nn.Cell): + """MaskedCrossEntropyLoss""" + + def __init__(self): + super(MaskedCrossEntropyLoss, self).__init__() + self.criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + + def construct(self, inputs, target): + losses = self.criterion(inputs, target) + return losses + + +class DiscretizedMixturelogisticLoss(nn.Cell): + """DiscretizedMixturelogisticLoss""" + + def __init__(self, hparams): + super(DiscretizedMixturelogisticLoss, self).__init__() + self.quantize_channels = 
hparams.quantize_channels + self.log_scale_min = hparams.log_scale_min + self.discretized_mix_logistic_loss = discretized_mix_logistic_loss(num_classes=hparams.quantize_channels, + log_scale_min=hparams.log_scale_min, + reduce=False) + self.reduce_sum_op = P.ReduceSum() + self.reduce_mean_op = P.ReduceMean() + + def construct(self, inputs, target, mask=None): + losses = self.discretized_mix_logistic_loss(inputs, target) + return self.reduce_sum_op(losses * mask) / self.reduce_sum_op(mask) + + +class MixtureGaussianLoss(nn.Cell): + """MixtureGaussianLoss""" + + def __init__(self, hparams): + super(MixtureGaussianLoss, self).__init__() + self.quantize_channels = hparams.quantize_channels + self.log_scale_min = hparams.log_scale_min + self.mix_gaussian_loss = mix_gaussian_loss(log_scale_min=hparams.log_scale_min, reduce=False) + self.reduce_sum_op = P.ReduceSum() + self.reduce_mean_op = P.ReduceMean() + + def construct(self, inputs, target, mask=None): + """ + + Args: + inputs (Tensor): Predicted distribution + target (Tensor): Target + mask (Tensor): Mask + + Returns: + Tensor: Loss tensor + + """ + losses = self.mix_gaussian_loss(inputs, target) + return self.reduce_sum_op(losses * mask) / self.reduce_sum_op(mask) + + +def save_waveplot(path, y_hat, y_target, sample_rate): + sr = sample_rate + plt.figure(figsize=(16, 6)) + plt.subplot(2, 1, 1) + librosa.display.waveplot(y_target, sr=sr) + plt.subplot(2, 1, 2) + librosa.display.waveplot(y_hat, sr=sr) + plt.tight_layout() + plt.savefig(path, format="png") + plt.close() + + +def eval_model(hparams, global_step, model, x, y, c, g, input_lengths, eval_dir): + """ + Function for model evaluation. This function is used for debugging in this project. 
+ """ + + model.set_train(False) + idx = np.random.randint(0, len(y)) + length = input_lengths.asnumpy()[idx] + y_target = np.reshape(y.asnumpy()[idx], (-1)) + y_target = y_target[:length] + + if c is not None: + expand_op = P.ExpandDims() + if hparams.upsample_conditional_features: + c = expand_op(c[idx, :, :int(length // audio.get_hop_size() + hparams.cin_pad * 2)], 0) + else: + c = expand_op(c[idx, :, :length], 0) + assert c.dim() == 3 + print("Shape of local conditioning features: {}".format(c.size())) + + if g is not None: + g = g[idx] + print("Shape of global conditioning features: {}".format(g.size())) + + # Dummy silence + if is_mulaw_quantize(hparams.input_type): + initial_value = P1.mulaw_quantize(0, hparams.quantize_channels - 1) + elif is_mulaw(hparams.input_type): + initial_value = P1.mulaw(0.0, hparams.quantize_channels) + else: + initial_value = 0.0 + + if is_mulaw_quantize(hparams.input_type): + initial_input = to_categorical( + initial_value, num_classes=hparams.quantize_channels).astype(np.float32) + initial_input = Tensor(np.reshape(initial_input, (1, 1, hparams.quantize_channels))) + + else: + initial_input = np.ones((1, 1, 1)) * initial_value + initial_input = Tensor(initial_input) + + # Run the model in fast eval mode + y_hat = model.incremental_forward(initial_input, c=c, g=g, T=length, softmax=True, quantize=True, tqdm=tqdm, + log_scale_min=hparams.log_scale_min) + + if is_mulaw_quantize(hparams.input_type): + y_hat = np.reshape(np.argmax(y_hat, 1), (-1)) + y_hat = P1.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1) + y_target = P1.inv_mulaw_quantize(y_target, hparams.quantize_channels - 1) + elif is_mulaw(hparams.input_type): + y_hat = P1.inv_mulaw(np.reshape(y_hat, (-1)), hparams.quantize_channels) + y_target = P1.inv_mulaw(y_target, hparams.quantize_channels) + else: + y_hat = np.reshape(y_hat, (-1)) + + # Save audio + os.makedirs(eval_dir, exist_ok=True) + path = os.path.join(eval_dir, 
"step{:09d}_predicted.wav".format(global_step)) + librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate) + + path = os.path.join(eval_dir, "step{:09d}_target.wav".format(global_step)) + librosa.output.write_wav(path, y_target, sr=hparams.sample_rate) + + # Save figure + path = os.path.join(eval_dir, "step{:09d}_waveplots.png".format(global_step)) + save_waveplot(path, y_hat, y_target, hparams.sample_rate) + + +class PredictNet(nn.Cell): + """ + NetWithLossClass definition + """ + + def __init__(self, network): + super(PredictNet, self).__init__(auto_prefix=False) + self.network = network + + def construct(self, x, c, g): + y_hat = self.network(x, c, g, False) + return y_hat + + +class NetWithLossClass(nn.Cell): + """ + NetWithLossClass definition + + Args: + network (Cell): Pre-defined WaveNet. + hparams (optional): Parameters. + + Returns: + Tensor, loss tensor. + """ + + def __init__(self, network, hparams): + super(NetWithLossClass, self).__init__(auto_prefix=False) + self.network = network + self.hparams = hparams + self.ReduceMean_false = P.ReduceMean(keep_dims=False) + self.expand_op = P.ExpandDims() + self.transpose_op = P.Transpose() + self.reshape_op = P.Reshape() + self.is_mulaw_quant = is_mulaw_quantize(hparams.input_type) + if self.is_mulaw_quant: + self.criterion = MaskedCrossEntropyLoss() + else: + if hparams.output_distribution == "Logistic": + self.criterion = DiscretizedMixturelogisticLoss(hparams) + elif hparams.output_distribution == "Normal": + self.criterion = MixtureGaussianLoss(hparams) + else: + self.criterion = None + raise RuntimeError( + "Not supported output distribution type: {}".format(hparams.output_distribution)) + + def construct(self, x, y, c, g, input_lengths, mask): + """ + + Args: + x (Tensor): input. + y (Tensor): prediction. + c (Tensor): Local conditioning feature. + g (Tensor): Global conditioning feature. + input_lengths(Tensor): input_lengths. + mask (Tensor): Padding mask. 
+ + Returns: + Tensor: Loss tensor + + """ + y_hat = self.network(x, c, g, False) + if self.is_mulaw_quant: + y_hat = self.transpose_op(y_hat[:, :, :-1], (0, 2, 1)) + y_hat = self.reshape_op(y_hat, (-1, y_hat.shape[-1])) + y = self.reshape_op(y[:, 1:, 0], (-1,)) + loss = self.criterion(y_hat, y) + else: + loss = self.criterion(y_hat[:, :, :-1], y[:, 1:, :], mask[:, 1:, :]) + return loss + + +GRADIENT_CLIP_TYPE = 1 +GRADIENT_CLIP_VALUE = 5.0 +clip_grad = C.MultitypeFuncGraph("clip_grad") + + +@clip_grad.register("Number", "Number", "Tensor") +def _clip_grad(clip_type, clip_value, grad): + """ + Clip gradients. + + Inputs: + clip_type (int): The way to clip, 0 for 'value', 1 for 'norm'. + clip_value (float): Specifies how much to clip. + grad (tuple[Tensor]): Gradients. + + Outputs: + tuple[Tensor], clipped gradients. + """ + if clip_type not in [0, 1]: + return grad + dt = F.dtype(grad) + if clip_type == 0: + new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt), + F.cast(F.tuple_to_array((clip_value,)), dt)) + else: + new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt)) + return new_grad + + +grad_scale = C.MultitypeFuncGraph("grad_scale") +reciprocal = P.Reciprocal() + + +@grad_scale.register("Tensor", "Tensor") +def tensor_grad_scale(scale, grad): + return grad * F.cast(reciprocal(scale), F.dtype(grad)) + + +_grad_overflow = C.MultitypeFuncGraph("_grad_overflow") +grad_overflow = P.FloatStatus() + + +@_grad_overflow.register("Tensor") +def _tensor_grad_overflow(grad): + return grad_overflow(grad) + + +compute_norm = C.MultitypeFuncGraph("compute_norm") + + +@compute_norm.register("Tensor") +def _compute_norm(grad): + norm = ops.norm(F.cast(grad, ms.float32)) + ret = F.expand_dims(F.cast(norm, ms.float32), 0) + return ret + + +grad_div = C.MultitypeFuncGraph("grad_div") + + +@grad_div.register("Tensor", "Tensor") +def _grad_div(val, grad): + div = P.RealDiv() + mul = P.Mul() + scale = div(1.0, val) + ret = mul(grad, 
scale) + return ret + + +class WaveNetTrainOneStepWithLossScaleCell(nn.Cell): + """ + WaveNet training with loss scaling. + + Args: + network (Cell): The training WaveNet. + optimizer (Cell): Optimizer for updating the weights. + scale_sense (Cell): The loss scaling update logic cell. + + Returns: + Tuple[Tensor, Tensor, Tensor], loss, overflow, sen. + """ + + def __init__(self, network, optimizer, scale_update_cell): + super(WaveNetTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.network.add_flags(defer_inline=True) + self.add_flags(has_effect=True) + self.weights = optimizer.parameters + self.optimizer = optimizer + + self.hyper_map = C.HyperMap() + self.grad = C.GradOperation(get_by_list=True, sens_param=True) + + self.sens = 1.0 + self.fill = P.Fill() + self.dtype = P.DType() + self.get_shape = P.Shape() + self.cast = P.Cast() + self.concat = P.Concat() + self.less_equal = P.LessEqual() + self.reduce_sum = P.ReduceSum(keep_dims=False) + self.scalar_summary = P.ScalarSummary() + self.greater = P.Greater() + self.select = P.Select() + self.alloc_status = P.NPUAllocFloatStatus() + self.get_status = P.NPUGetFloatStatus() + self.clear_before_grad = P.NPUClearFloatStatus() + self.is_distributed = False + self.base = Tensor(1, ms.float32) + + self.all_reduce = P.AllReduce() + + self.loss_scaling_manager = scale_update_cell + self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=ms.float32)) + + self.reducer_flag = False + self.parallel_mode = context.get_auto_parallel_context("parallel_mode") + if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]: + self.reducer_flag = True + self.is_distributed = True + self.grad_reducer = F.identity + self.degree = 1 + if self.reducer_flag: + self.degree = get_group_size() + mean = _get_gradients_mean() + self.grad_reducer = DistributedGradReducer(self.weights, mean, self.degree) + + def construct(self, x, 
y, c, g, input_lengths, mask): + """ + + Args: + x (Tensor): Source audio signal. + y (Tensor): Target audio signal. + c (Tensor): Local conditioning feature. + g (Tensor): Global conditioning feature. + input_lengths(Tensor): input_lengths + mask (Tensor): Padding mask. + + Returns: + Tuple[Tensor, Tensor, Tensor], loss, overflow, sen. + + """ + weights = self.weights + loss = self.network(x, y, c, g, input_lengths, mask) + + scale_sense = self.loss_scale + # Alloc status. + init = self.alloc_status() + init = F.depend(init, loss) + + # Clear overflow buffer. + clear_status = self.clear_before_grad(init) + scale_sense = F.depend(scale_sense, clear_status) + + grads = self.grad(self.network, weights)(x, y, c, g, input_lengths, mask, self.cast(scale_sense, ms.float32)) + grads = self.grad_reducer(grads) + grads = self.hyper_map(F.partial(grad_scale, self.degree * scale_sense), grads) + grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) + flag_sum = self.reduce_sum(init, (0,)) + + if self.is_distributed: + # Sum overflow flag over devices. 
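The float-status flags gathered here drive a standard dynamic loss-scaling policy: an overflow step is skipped and the scale is lowered, while a long enough run of clean steps raises it again. A toy sketch of that policy (the factor and window values are illustrative; the actual behavior is whatever `scale_update_cell` implements):

```python
class DynamicLossScaler:
    """Toy dynamic loss-scale manager: halve on overflow, double after
    `window` consecutive overflow-free steps."""

    def __init__(self, init_scale=2.0 ** 16, factor=2.0, window=2000):
        self.scale = init_scale
        self.factor = factor
        self.window = window
        self.good_steps = 0

    def update(self, overflow):
        """Return True if the optimizer step should be skipped."""
        if overflow:
            # Gradients are unusable: shrink the scale, skip the update.
            self.scale = max(self.scale / self.factor, 1.0)
            self.good_steps = 0
            return True
        self.good_steps += 1
        if self.good_steps >= self.window:
            # Long clean streak: try a larger scale for better fp16 precision.
            self.scale *= self.factor
            self.good_steps = 0
        return False

scaler = DynamicLossScaler(init_scale=1024.0, window=2)
print(scaler.update(True), scaler.scale)    # True 512.0
print(scaler.update(False), scaler.scale)   # False 512.0
print(scaler.update(False), scaler.scale)   # False 1024.0
```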
+ flag_reduce = self.all_reduce(flag_sum) + cond = self.less_equal(self.base, flag_reduce) + else: + cond = self.less_equal(self.base, flag_sum) + + overflow = self.loss_scaling_manager(self.loss_scale, cond) + + if overflow: + succ = False + else: + succ = self.optimizer(grads) + + self.scalar_summary("training.loss", loss) + + ret = (loss, scale_sense.value()) + return F.depend(ret, succ) diff --git a/HuaWeiExperiment/wavenet/src/lr_generator.py b/HuaWeiExperiment/wavenet/src/lr_generator.py new file mode 100644 index 00000000..183116dc --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/lr_generator.py @@ -0,0 +1,56 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Learning rate generator. 
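`get_lr` below is plain step decay: the rate is multiplied by `anneal_rate` once every `anneal_interval` steps. A standalone sketch of the same schedule (the numbers are illustrative):

```python
import numpy as np

def step_decay_lr(init_lr, total_step, anneal_rate=0.5, anneal_interval=100000):
    # lr(i) = init_lr * anneal_rate ** (i // anneal_interval)
    steps = np.arange(total_step)
    return (init_lr * anneal_rate ** (steps // anneal_interval)).astype(np.float32)

lr = step_decay_lr(1e-3, total_step=250000, anneal_interval=100000)
print(lr[0], lr[100000], lr[200000])  # 0.001 0.0005 0.00025
```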
+""" +import numpy as np + + +# for GPU/CPU +def get_lr(init_lr, total_epoch, step_per_epoch, + anneal_rate=0.5, + anneal_interval=100000): + """ + Learning rate generating + + Args: + init_lr (float): Initial learning rate + total_epoch (int): Total epoch + step_per_epoch (int): Step per epoch + anneal_rate (float): anneal rate + anneal_interval (int ): anneal interval + + Returns: + ndarray: learning rate + + """ + total_step = total_epoch * step_per_epoch + lr_step = [] + for i in range(total_step): + lr_step.append(init_lr * anneal_rate ** (i // anneal_interval)) + learning_rate = np.array(lr_step).astype(np.float32) + return learning_rate + + +# for Ascend +def get_lrv2(init_lr, total_epoch, step_per_epoch, + anneal_step=250): + total_step = total_epoch * step_per_epoch + lr_step = [] + + for step in range(total_step): + lambda_lr = anneal_step ** 0.5 * min((step + 1) * anneal_step ** -1.5, (step + 1) ** -0.5) + lr_step.append(init_lr * lambda_lr) + learning_rate = np.array(lr_step).astype(np.float32) + return learning_rate diff --git a/HuaWeiExperiment/wavenet/src/train_pytorch.py b/HuaWeiExperiment/wavenet/src/train_pytorch.py new file mode 100644 index 00000000..8f8e56e4 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/train_pytorch.py @@ -0,0 +1,1117 @@ +"""Trainining script for WaveNet vocoder + +usage: train.py [options] + +options: + --dump-root= Directory contains preprocessed features. + --checkpoint-dir= Directory where to save model checkpoints [default: checkpoints]. + --hparams= Hyper parameters [default: ]. + --preset= Path of preset parameters (json). + --checkpoint= Restore model from checkpoint path if given. + --restore-parts= Restore part of the model. + --log-event-path= Log event path. + --reset-optimizer Reset optimizer. + --speaker-id= Use specific speaker of data in case for multi-speaker datasets. 
+ -h, --help Show this help message and exit +""" +from docopt import docopt + +import sys + +import os +from os.path import dirname, join, expanduser, exists +from tqdm import tqdm +from datetime import datetime +import random +import json +from glob import glob + +import numpy as np + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + + +import torch +from torch import nn +from torch.nn import functional as F +from torch import optim +import torch.backends.cudnn as cudnn +from torch.utils import data as data_utils +from torch.utils.data.sampler import Sampler + +import torch.optim.lr_scheduler as lrschedule + + +from nnmnkwii import preprocessing as P +from nnmnkwii.datasets import FileSourceDataset, FileDataSource + +import librosa.display + +from tensorboardX import SummaryWriter +from matplotlib import cm +from warnings import warn + +from wavenet_vocoder import WaveNet +from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_raw, is_scalar_input +from wavenet_vocoder.mixture import discretized_mix_logistic_loss +from wavenet_vocoder.mixture import sample_from_discretized_mix_logistic +from wavenet_vocoder.mixture import mix_gaussian_loss +from wavenet_vocoder.mixture import sample_from_mix_gaussian + +import audio +from hparams import hparams, hparams_debug_string + + +global_step = 0 +global_test_step = 0 +global_epoch = 0 +use_cuda = torch.cuda.is_available() +if use_cuda: + cudnn.benchmark = True + + +def sanity_check(model, c, g): + if model.has_speaker_embedding(): + if g is None: + raise RuntimeError( + "WaveNet expects speaker embedding, but speaker-id is not provided") + else: + if g is not None: + raise RuntimeError( + "WaveNet expects no speaker embedding, but speaker-id is provided") + + if model.local_conditioning_enabled(): + if c is None: + raise RuntimeError("WaveNet expects conditional features, but not given") + else: + if c is not None: + raise RuntimeError("WaveNet expects no conditional features, but 
given") + + +def maybe_set_epochs_based_on_max_steps(hp, steps_per_epoch): + nepochs = hp.nepochs + max_train_steps = hp.max_train_steps + if max_train_steps is not None: + epochs = int(np.ceil(max_train_steps / steps_per_epoch)) + hp.nepochs = epochs + print("info; Number of epochs is set based on max_train_steps: {}".format(epochs)) + + +def _pad(seq, max_len, constant_values=0): + return np.pad(seq, (0, max_len - len(seq)), + mode='constant', constant_values=constant_values) + + +def _pad_2d(x, max_len, b_pad=0, constant_values=0): + x = np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)], + mode="constant", constant_values=constant_values) + return x + +# from: https://github.com/keras-team/keras/blob/master/keras/utils/np_utils.py +# to avoid keras dependency + + +def to_categorical(y, num_classes=None, dtype='float32'): + """Converts a class vector (integers) to binary class matrix. + E.g. for use with categorical_crossentropy. + # Arguments + y: class vector to be converted into a matrix + (integers from 0 to num_classes). + num_classes: total number of classes. + dtype: The data type expected by the input, as a string + (`float32`, `float64`, `int32`...) + # Returns + A binary matrix representation of the input. The classes axis + is placed last. + # Example + ```python + # Consider an array of 5 labels out of a set of 3 classes {0, 1, 2}: + > labels + array([0, 2, 1, 2, 0]) + # `to_categorical` converts this into a matrix with as many + # columns as there are classes. The number of rows + # stays the same. 
+ > to_categorical(labels) + array([[ 1., 0., 0.], + [ 0., 0., 1.], + [ 0., 1., 0.], + [ 0., 0., 1.], + [ 1., 0., 0.]], dtype=float32) + ``` + """ + + y = np.array(y, dtype='int') + input_shape = y.shape + if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: + input_shape = tuple(input_shape[:-1]) + y = y.ravel() + if not num_classes: + num_classes = np.max(y) + 1 + n = y.shape[0] + categorical = np.zeros((n, num_classes), dtype=dtype) + categorical[np.arange(n), y] = 1 + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) + return categorical + + +# TODO: I know this is too ugly... +class _NPYDataSource(FileDataSource): + def __init__(self, dump_root, col, typ="", speaker_id=None, max_steps=8000, + cin_pad=0, hop_size=256): + self.dump_root = dump_root + self.col = col + self.lengths = [] + self.speaker_id = speaker_id + self.multi_speaker = False + self.speaker_ids = None + self.max_steps = max_steps + self.cin_pad = cin_pad + self.hop_size = hop_size + self.typ = typ + + def collect_files(self): + meta = join(self.dump_root, "train.txt") + if not exists(meta): + paths = sorted(glob(join(self.dump_root, "*-{}.npy".format(self.typ)))) + return paths + + with open(meta, "rb") as f: + lines = f.readlines() + l = lines[0].decode("utf-8").split("|") + assert len(l) == 4 or len(l) == 5 + self.multi_speaker = len(l) == 5 + self.lengths = list( + map(lambda l: int(l.decode("utf-8").split("|")[2]), lines)) + + paths_relative = list(map(lambda l: l.decode("utf-8").split("|")[self.col], lines)) + paths = list(map(lambda f: join(self.dump_root, f), paths_relative)) + + # Exclude small files (assuming lengths are in frame units) + # TODO: consider this for multi-speaker + if self.max_steps is not None: + idx = np.array(self.lengths) * self.hop_size > self.max_steps + 2 * self.cin_pad * self.hop_size + if idx.sum() != len(self.lengths): + print("{} short samples are omitted for training.".format(len(self.lengths) -
idx.sum())) + self.lengths = list(np.array(self.lengths)[idx]) + paths = list(np.array(paths)[idx]) + + if self.multi_speaker: + speaker_ids = list(map(lambda l: int(l.decode("utf-8").split("|")[-1]), lines)) + self.speaker_ids = speaker_ids + if self.speaker_id is not None: + # Filter by speaker_id + # using multi-speaker dataset as a single speaker dataset + indices = np.array(speaker_ids) == self.speaker_id + paths = list(np.array(paths)[indices]) + self.lengths = list(np.array(self.lengths)[indices]) + # aha, need to cast numpy.int64 to int + self.lengths = list(map(int, self.lengths)) + self.multi_speaker = False + + if self.multi_speaker: + speaker_ids_np = list(np.array(self.speaker_ids)[indices]) + self.speaker_ids = list(map(int, speaker_ids_np)) + assert len(paths) == len(self.speaker_ids) + + return paths + + def collect_features(self, path): + return np.load(path) + + +class RawAudioDataSource(_NPYDataSource): + def __init__(self, dump_root, **kwargs): + super(RawAudioDataSource, self).__init__(dump_root, 0, "wave", **kwargs) + + +class MelSpecDataSource(_NPYDataSource): + def __init__(self, dump_root, **kwargs): + super(MelSpecDataSource, self).__init__(dump_root, 1, "feats", **kwargs) + + +class PartialyRandomizedSimilarTimeLengthSampler(Sampler): + """Partially randomized sampler + + 1. Sort by lengths + 2. Pick a small patch and randomize it + 3. 
Permute mini-batches + """ + + def __init__(self, lengths, batch_size=8, batch_group_size=None): + self.lengths, self.sorted_indices = torch.sort(torch.LongTensor(lengths)) + + self.batch_size = batch_size + if batch_group_size is None: + batch_group_size = min(batch_size * 8, len(self.lengths)) + if batch_group_size % batch_size != 0: + batch_group_size -= batch_group_size % batch_size + + self.batch_group_size = batch_group_size + assert batch_group_size % batch_size == 0 + + def __iter__(self): + indices = self.sorted_indices.numpy() + batch_group_size = self.batch_group_size + s, e = 0, 0 + bins = [] + for i in range(len(indices) // batch_group_size): + s = i * batch_group_size + e = s + batch_group_size + group = indices[s:e] + random.shuffle(group) + bins += [group] + + # Permute batches + random.shuffle(bins) + binned_idx = np.stack(bins).reshape(-1) + + # Handle last elements + s += batch_group_size + if s < len(indices): + last_bin = indices[len(binned_idx):] + random.shuffle(last_bin) + binned_idx = np.concatenate([binned_idx, last_bin]) + + return iter(torch.tensor(binned_idx).long()) + + def __len__(self): + return len(self.sorted_indices) + + +class PyTorchDataset(object): + def __init__(self, X, Mel): + self.X = X + self.Mel = Mel + # alias + self.multi_speaker = X.file_data_source.multi_speaker + + def __getitem__(self, idx): + if self.Mel is None: + mel = None + else: + mel = self.Mel[idx] + + raw_audio = self.X[idx] + if self.multi_speaker: + speaker_id = self.X.file_data_source.speaker_ids[idx] + else: + speaker_id = None + + # (x,c,g) + return raw_audio, mel, speaker_id + + def __len__(self): + return len(self.X) + + +def sequence_mask(sequence_length, max_len=None): + if max_len is None: + max_len = sequence_length.data.max() + batch_size = sequence_length.size(0) + seq_range = torch.arange(0, max_len).long() + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + if sequence_length.is_cuda: + seq_range_expand =
seq_range_expand.cuda() + seq_length_expand = sequence_length.unsqueeze(1) \ + .expand_as(seq_range_expand) + return (seq_range_expand < seq_length_expand).float() + + +# https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/4 +# https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage +class ExponentialMovingAverage(object): + def __init__(self, decay): + self.decay = decay + self.shadow = {} + + def register(self, name, val): + self.shadow[name] = val.clone() + + def update(self, name, x): + assert name in self.shadow + update_delta = self.shadow[name] - x + self.shadow[name] -= (1.0 - self.decay) * update_delta + + +def clone_as_averaged_model(device, model, ema): + assert ema is not None + averaged_model = build_model().to(device) + averaged_model.load_state_dict(model.state_dict()) + for name, param in averaged_model.named_parameters(): + if name in ema.shadow: + param.data = ema.shadow[name].clone() + return averaged_model + + +class MaskedCrossEntropyLoss(nn.Module): + def __init__(self): + super(MaskedCrossEntropyLoss, self).__init__() + self.criterion = nn.CrossEntropyLoss(reduction='none') + + def forward(self, input, target, lengths=None, mask=None, max_len=None): + if lengths is None and mask is None: + raise RuntimeError("Should provide either lengths or mask") + + # (B, T, 1) + if mask is None: + mask = sequence_mask(lengths, max_len).unsqueeze(-1) + + # (B, T, D) + mask_ = mask.expand_as(target) + losses = self.criterion(input, target) + return ((losses * mask_).sum()) / mask_.sum() + + +class DiscretizedMixturelogisticLoss(nn.Module): + def __init__(self): + super(DiscretizedMixturelogisticLoss, self).__init__() + + def forward(self, input, target, lengths=None, mask=None, max_len=None): + if lengths is None and mask is None: + raise RuntimeError("Should provide either lengths or mask") + + # (B, T, 1) + if mask is None: + mask = sequence_mask(lengths, max_len).unsqueeze(-1) + + # (B, T, 
1) + mask_ = mask.expand_as(target) + + losses = discretized_mix_logistic_loss( + input, target, num_classes=hparams.quantize_channels, + log_scale_min=hparams.log_scale_min, reduce=False) + assert losses.size() == target.size() + return ((losses * mask_).sum()) / mask_.sum() + + +class MixtureGaussianLoss(nn.Module): + def __init__(self): + super(MixtureGaussianLoss, self).__init__() + + def forward(self, input, target, lengths=None, mask=None, max_len=None): + if lengths is None and mask is None: + raise RuntimeError("Should provide either lengths or mask") + + # (B, T, 1) + if mask is None: + mask = sequence_mask(lengths, max_len).unsqueeze(-1) + + # (B, T, 1) + mask_ = mask.expand_as(target) + + losses = mix_gaussian_loss( + input, target, log_scale_min=hparams.log_scale_min, reduce=False) + assert losses.size() == target.size() + return ((losses * mask_).sum()) / mask_.sum() + + +def ensure_divisible(length, divisible_by=256, lower=True): + if length % divisible_by == 0: + return length + if lower: + return length - length % divisible_by + else: + return length + (divisible_by - length % divisible_by) + + +def assert_ready_for_upsampling(x, c, cin_pad): + assert len(x) == (len(c) - 2 * cin_pad) * audio.get_hop_size() + + +def collate_fn(batch): + """Create batch + + Args: + batch(tuple): List of tuples + - x[0] (ndarray,int) : list of (T,) + - x[1] (ndarray,int) : list of (T, D) + - x[2] (ndarray,int) : list of (1,), speaker id + Returns: + tuple: Tuple of batch + - x (FloatTensor) : Network inputs (B, C, T) + - y (LongTensor) : Network targets (B, T, 1) + """ + + local_conditioning = len(batch[0]) >= 2 and hparams.cin_channels > 0 + global_conditioning = len(batch[0]) >= 3 and hparams.gin_channels > 0 + + if hparams.max_time_sec is not None: + max_time_steps = int(hparams.max_time_sec * hparams.sample_rate) + elif hparams.max_time_steps is not None: + max_time_steps = hparams.max_time_steps + else: + max_time_steps = None + + # Time resolution adjustment + 
cin_pad = hparams.cin_pad + if local_conditioning: + new_batch = [] + for idx in range(len(batch)): + x, c, g = batch[idx] + if hparams.upsample_conditional_features: + assert_ready_for_upsampling(x, c, cin_pad=0) + if max_time_steps is not None: + max_steps = ensure_divisible(max_time_steps, audio.get_hop_size(), True) + if len(x) > max_steps: + max_time_frames = max_steps // audio.get_hop_size() + s = np.random.randint(cin_pad, len(c) - max_time_frames - cin_pad) + ts = s * audio.get_hop_size() + x = x[ts:ts + audio.get_hop_size() * max_time_frames] + c = c[s - cin_pad:s + max_time_frames + cin_pad, :] + assert_ready_for_upsampling(x, c, cin_pad=cin_pad) + else: + x, c = audio.adjust_time_resolution(x, c) + if max_time_steps is not None and len(x) > max_time_steps: + s = np.random.randint(cin_pad, len(x) - max_time_steps - cin_pad) + x = x[s:s + max_time_steps] + c = c[s - cin_pad:s + max_time_steps + cin_pad, :] + assert len(x) == len(c) + new_batch.append((x, c, g)) + batch = new_batch + else: + new_batch = [] + for idx in range(len(batch)): + x, c, g = batch[idx] + x = audio.trim(x) + if max_time_steps is not None and len(x) > max_time_steps: + s = np.random.randint(0, len(x) - max_time_steps) + if local_conditioning: + x, c = x[s:s + max_time_steps], c[s:s + max_time_steps, :] + else: + x = x[s:s + max_time_steps] + new_batch.append((x, c, g)) + batch = new_batch + + # Lengths + input_lengths = [len(x[0]) for x in batch] + max_input_len = max(input_lengths) + + # (B, T, C) + # pad for time-axis + if is_mulaw_quantize(hparams.input_type): + padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1) + x_batch = np.array([_pad_2d(to_categorical( + x[0], num_classes=hparams.quantize_channels), + max_input_len, 0, padding_value) for x in batch], dtype=np.float32) + else: + x_batch = np.array([_pad_2d(x[0].reshape(-1, 1), max_input_len) + for x in batch], dtype=np.float32) + assert len(x_batch.shape) == 3 + + # (B, T) + if 
is_mulaw_quantize(hparams.input_type): + padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1) + y_batch = np.array([_pad(x[0], max_input_len, constant_values=padding_value) + for x in batch], dtype=np.int64) + else: + y_batch = np.array([_pad(x[0], max_input_len) for x in batch], dtype=np.float32) + assert len(y_batch.shape) == 2 + + # (B, T, D) + if local_conditioning: + max_len = max([len(x[1]) for x in batch]) + c_batch = np.array([_pad_2d(x[1], max_len) for x in batch], dtype=np.float32) + assert len(c_batch.shape) == 3 + # (B x C x T) + c_batch = torch.FloatTensor(c_batch).transpose(1, 2).contiguous() + else: + c_batch = None + + if global_conditioning: + g_batch = torch.LongTensor([x[2] for x in batch]) + else: + g_batch = None + + # Convert to channel first i.e., (B, C, T) + x_batch = torch.FloatTensor(x_batch).transpose(1, 2).contiguous() + # Add extra axis + if is_mulaw_quantize(hparams.input_type): + y_batch = torch.LongTensor(y_batch).unsqueeze(-1).contiguous() + else: + y_batch = torch.FloatTensor(y_batch).unsqueeze(-1).contiguous() + + input_lengths = torch.LongTensor(input_lengths) + + return x_batch, y_batch, c_batch, g_batch, input_lengths + + +def time_string(): + return datetime.now().strftime('%Y-%m-%d %H:%M') + + +def save_waveplot(path, y_hat, y_target): + sr = hparams.sample_rate + + plt.figure(figsize=(16, 6)) + plt.subplot(2, 1, 1) + librosa.display.waveplot(y_target, sr=sr) + plt.subplot(2, 1, 2) + librosa.display.waveplot(y_hat, sr=sr) + plt.tight_layout() + plt.savefig(path, format="png") + plt.close() + + +def eval_model(global_step, writer, device, model, y, c, g, input_lengths, eval_dir, ema=None): + if ema is not None: + print("Using averaged model for evaluation") + model = clone_as_averaged_model(device, model, ema) + model.make_generation_fast_() + + model.eval() + idx = np.random.randint(0, len(y)) + length = input_lengths[idx].data.cpu().item() + + # (T,) + y_target = y[idx].view(-1).data.cpu().numpy()[:length] +
if c is not None: + if hparams.upsample_conditional_features: + c = c[idx, :, :length // audio.get_hop_size() + hparams.cin_pad * 2].unsqueeze(0) + else: + c = c[idx, :, :length].unsqueeze(0) + assert c.dim() == 3 + print("Shape of local conditioning features: {}".format(c.size())) + if g is not None: + # TODO: test + g = g[idx] + print("Shape of global conditioning features: {}".format(g.size())) + + # Dummy silence + if is_mulaw_quantize(hparams.input_type): + initial_value = P.mulaw_quantize(0, hparams.quantize_channels - 1) + elif is_mulaw(hparams.input_type): + initial_value = P.mulaw(0.0, hparams.quantize_channels) + else: + initial_value = 0.0 + + # (C,) + if is_mulaw_quantize(hparams.input_type): + initial_input = to_categorical( + initial_value, num_classes=hparams.quantize_channels).astype(np.float32) + initial_input = torch.from_numpy(initial_input).view( + 1, 1, hparams.quantize_channels) + else: + initial_input = torch.zeros(1, 1, 1).fill_(initial_value) + initial_input = initial_input.to(device) + + # Run the model in fast eval mode + with torch.no_grad(): + y_hat = model.incremental_forward( + initial_input, c=c, g=g, T=length, softmax=True, quantize=True, tqdm=tqdm, + log_scale_min=hparams.log_scale_min) + + if is_mulaw_quantize(hparams.input_type): + y_hat = y_hat.max(1)[1].view(-1).long().cpu().data.numpy() + y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1) + y_target = P.inv_mulaw_quantize(y_target, hparams.quantize_channels - 1) + elif is_mulaw(hparams.input_type): + y_hat = P.inv_mulaw(y_hat.view(-1).cpu().data.numpy(), hparams.quantize_channels) + y_target = P.inv_mulaw(y_target, hparams.quantize_channels) + else: + y_hat = y_hat.view(-1).cpu().data.numpy() + + # Save audio + os.makedirs(eval_dir, exist_ok=True) + path = join(eval_dir, "step{:09d}_predicted.wav".format(global_step)) + librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate) + path = join(eval_dir, "step{:09d}_target.wav".format(global_step)) + 
librosa.output.write_wav(path, y_target, sr=hparams.sample_rate) + + # save figure + path = join(eval_dir, "step{:09d}_waveplots.png".format(global_step)) + save_waveplot(path, y_hat, y_target) + + +def save_states(global_step, writer, y_hat, y, input_lengths, checkpoint_dir=None): + print("Save intermediate states at step {}".format(global_step)) + idx = np.random.randint(0, len(y_hat)) + length = input_lengths[idx].data.cpu().item() + + # (B, C, T) + if y_hat.dim() == 4: + y_hat = y_hat.squeeze(-1) + + if is_mulaw_quantize(hparams.input_type): + # (B, T) + y_hat = F.softmax(y_hat, dim=1).max(1)[1] + + # (T,) + y_hat = y_hat[idx].data.cpu().long().numpy() + y = y[idx].view(-1).data.cpu().long().numpy() + + y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1) + y = P.inv_mulaw_quantize(y, hparams.quantize_channels - 1) + else: + # (B, T) + if hparams.output_distribution == "Logistic": + y_hat = sample_from_discretized_mix_logistic( + y_hat, log_scale_min=hparams.log_scale_min) + elif hparams.output_distribution == "Normal": + y_hat = sample_from_mix_gaussian( + y_hat, log_scale_min=hparams.log_scale_min) + else: + assert False + + # (T,) + y_hat = y_hat[idx].view(-1).data.cpu().numpy() + y = y[idx].view(-1).data.cpu().numpy() + + if is_mulaw(hparams.input_type): + y_hat = P.inv_mulaw(y_hat, hparams.quantize_channels) + y = P.inv_mulaw(y, hparams.quantize_channels) + + # Mask by length + y_hat[length:] = 0 + y[length:] = 0 + + # Save audio + audio_dir = join(checkpoint_dir, "intermediate", "audio") + os.makedirs(audio_dir, exist_ok=True) + path = join(audio_dir, "step{:09d}_predicted.wav".format(global_step)) + librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate) + path = join(audio_dir, "step{:09d}_target.wav".format(global_step)) + librosa.output.write_wav(path, y, sr=hparams.sample_rate) + +# workaround for https://github.com/pytorch/pytorch/issues/15716 +# the idea is to return outputs and replicas explicitly, so that making pytorch +# not 
to release the nodes (this is a pytorch bug though) + + +def data_parallel_workaround(model, input): + device_ids = list(range(torch.cuda.device_count())) + output_device = device_ids[0] + replicas = torch.nn.parallel.replicate(model, device_ids) + inputs = torch.nn.parallel.scatter(input, device_ids) + replicas = replicas[:len(inputs)] + outputs = torch.nn.parallel.parallel_apply(replicas, inputs) + y_hat = torch.nn.parallel.gather(outputs, output_device) + return y_hat, outputs, replicas + + +def __train_step(device, phase, epoch, global_step, global_test_step, + model, optimizer, writer, criterion, + x, y, c, g, input_lengths, + checkpoint_dir, eval_dir=None, do_eval=False, ema=None): + sanity_check(model, c, g) + + # x : (B, C, T) + # y : (B, T, 1) + # c : (B, C, T) + # g : (B,) + train = (phase == "train_no_dev") + clip_thresh = hparams.clip_thresh + if train: + model.train() + step = global_step + else: + model.eval() + step = global_test_step + + # Learning rate schedule + current_lr = hparams.optimizer_params["lr"] + if train and hparams.lr_schedule is not None: + lr_schedule_f = getattr(lrschedule, hparams.lr_schedule) + current_lr = lr_schedule_f( + hparams.optimizer_params["lr"], step, **hparams.lr_schedule_kwargs) + for param_group in optimizer.param_groups: + param_group['lr'] = current_lr + optimizer.zero_grad() + + # Prepare data + x, y = x.to(device), y.to(device) + input_lengths = input_lengths.to(device) + c = c.to(device) if c is not None else None + g = g.to(device) if g is not None else None + + # (B, T, 1) + mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1) + mask = mask[:, 1:, :] + + # Apply model: Run the model in regular eval mode + # NOTE: softmax is handled inside F.cross_entropy + # y_hat: (B x C x T) + + if use_cuda: + # multi-GPU support + # you must make sure that batch size % num gpu == 0 + y_hat, _outputs, _replicas = data_parallel_workaround(model, (x, c, g, False)) + else: + y_hat = model(x, c, g, False) + +
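The slicing in the loss computation that follows, `y_hat[:, :, :-1]` scored against `y[:, 1:, :]`, is the one-sample teacher-forcing shift that makes WaveNet autoregressive: the prediction at position t is compared with the ground-truth sample at t+1. A tiny NumPy sketch of that alignment, with hypothetical toy shapes:

```python
import numpy as np

B, C, T = 2, 4, 6  # toy batch, channel (class), and time sizes
y_hat = np.random.rand(B, C, T)              # network output at every input position
y = np.random.randint(0, C, size=(B, T, 1))  # ground-truth quantized samples

pred = y_hat[:, :, :-1]  # (B, C, T-1): drop the prediction made at the last step
target = y[:, 1:, :]     # (B, T-1, 1): drop the first sample (it has no predecessor)

# Both now cover the same T-1 (input, next-sample) pairs.
assert pred.shape == (B, C, T - 1)
assert target.shape == (B, T - 1, 1)
```

The `mask[:, 1:, :]` computed above applies the same shift to the length mask, so padded positions never contribute to the loss.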
if is_mulaw_quantize(hparams.input_type): + # we need 4D inputs for the spatial cross-entropy loss + # (B, C, T, 1) + y_hat = y_hat.unsqueeze(-1) + loss = criterion(y_hat[:, :, :-1, :], y[:, 1:, :], mask=mask) + else: + loss = criterion(y_hat[:, :, :-1], y[:, 1:, :], mask=mask) + + if train and step > 0 and step % hparams.checkpoint_interval == 0: + save_states(step, writer, y_hat, y, input_lengths, checkpoint_dir) + save_checkpoint(device, model, optimizer, step, checkpoint_dir, epoch, ema) + + if do_eval: + # NOTE: use train step (i.e., global_step) for filename + eval_model(global_step, writer, device, model, y, c, g, input_lengths, eval_dir, ema) + + # Update + if train: + loss.backward() + if clip_thresh > 0: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_thresh) + optimizer.step() + # update moving average + if ema is not None: + for name, param in model.named_parameters(): + if name in ema.shadow: + ema.update(name, param.data) + + # Logs + writer.add_scalar("{} loss".format(phase), float(loss.item()), step) + if train: + if clip_thresh > 0: + writer.add_scalar("gradient norm", grad_norm, step) + writer.add_scalar("learning rate", current_lr, step) + + return loss.item() + + +def train_loop(device, model, data_loaders, optimizer, writer, checkpoint_dir=None): + if is_mulaw_quantize(hparams.input_type): + criterion = MaskedCrossEntropyLoss() + else: + if hparams.output_distribution == "Logistic": + criterion = DiscretizedMixturelogisticLoss() + elif hparams.output_distribution == "Normal": + criterion = MixtureGaussianLoss() + else: + raise RuntimeError( + "Unsupported output distribution type: {}".format( + hparams.output_distribution)) + + if hparams.exponential_moving_average: + ema = ExponentialMovingAverage(hparams.ema_decay) + for name, param in model.named_parameters(): + if param.requires_grad: + ema.register(name, param.data) + else: + ema = None + + global global_step, global_epoch, global_test_step + while global_epoch < 
hparams.nepochs: + for phase, data_loader in data_loaders.items(): + train = (phase == "train_no_dev") + running_loss = 0. + test_evaluated = False + for step, (x, y, c, g, input_lengths) in tqdm(enumerate(data_loader)): + # Whether to save eval (i.e., online decoding) result + do_eval = False + eval_dir = join(checkpoint_dir, "intermediate", "{}_eval".format(phase)) + # Do eval per eval_interval for train + if train and global_step > 0 \ + and global_step % hparams.train_eval_interval == 0: + do_eval = True + # Do eval for test + # NOTE: Decoding WaveNet is quite time consuming, so + # do only once in a single epoch for testset + if not train and not test_evaluated \ + and global_epoch % hparams.test_eval_epoch_interval == 0: + do_eval = True + test_evaluated = True + if do_eval: + print("[{}] Eval at train step {}".format(phase, global_step)) + + # Do step + running_loss += __train_step(device, + phase, global_epoch, global_step, global_test_step, model, + optimizer, writer, criterion, x, y, c, g, input_lengths, + checkpoint_dir, eval_dir, do_eval, ema) + + # update global state + if train: + global_step += 1 + else: + global_test_step += 1 + + if global_step >= hparams.max_train_steps: + print("Training reached max train steps ({}). 
will exit".format(hparams.max_train_steps)) + return ema + + # log per epoch + averaged_loss = running_loss / len(data_loader) + writer.add_scalar("{} loss (per epoch)".format(phase), + averaged_loss, global_epoch) + print("Step {} [{}] Loss: {}".format( + global_step, phase, running_loss / len(data_loader))) + + global_epoch += 1 + return ema + + +def save_checkpoint(device, model, optimizer, step, checkpoint_dir, epoch, ema=None): + checkpoint_path = join( + checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step)) + optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None + global global_test_step + torch.save({ + "state_dict": model.state_dict(), + "optimizer": optimizer_state, + "global_step": step, + "global_epoch": epoch, + "global_test_step": global_test_step, + }, checkpoint_path) + print("Saved checkpoint:", checkpoint_path) + + import shutil + latest_pth = join(checkpoint_dir, "checkpoint_latest.pth") + shutil.copyfile(checkpoint_path, latest_pth) + + if ema is not None: + averaged_model = clone_as_averaged_model(device, model, ema) + checkpoint_path = join( + checkpoint_dir, "checkpoint_step{:09d}_ema.pth".format(global_step)) + torch.save({ + "state_dict": averaged_model.state_dict(), + "optimizer": optimizer_state, + "global_step": step, + "global_epoch": epoch, + "global_test_step": global_test_step, + }, checkpoint_path) + print("Saved averaged checkpoint:", checkpoint_path) + + latest_pth = join(checkpoint_dir, "checkpoint_latest_ema.pth") + shutil.copyfile(checkpoint_path, latest_pth) + + +def build_model(): + if is_mulaw_quantize(hparams.input_type): + if hparams.out_channels != hparams.quantize_channels: + raise RuntimeError( + "out_channels must equal quantize_channels if input_type is 'mulaw-quantize'") + if hparams.upsample_conditional_features and hparams.cin_channels < 0: + s = "Upsample conv layers were specified while local conditioning disabled. 
" + s += "Notice that upsample conv layers will never be used." + warn(s) + + upsample_params = hparams.upsample_params + upsample_params["cin_channels"] = hparams.cin_channels + upsample_params["cin_pad"] = hparams.cin_pad + model = WaveNet( + out_channels=hparams.out_channels, + layers=hparams.layers, + stacks=hparams.stacks, + residual_channels=hparams.residual_channels, + gate_channels=hparams.gate_channels, + skip_out_channels=hparams.skip_out_channels, + cin_channels=hparams.cin_channels, + gin_channels=hparams.gin_channels, + n_speakers=hparams.n_speakers, + dropout=hparams.dropout, + kernel_size=hparams.kernel_size, + cin_pad=hparams.cin_pad, + upsample_conditional_features=hparams.upsample_conditional_features, + upsample_params=upsample_params, + scalar_input=is_scalar_input(hparams.input_type), + output_distribution=hparams.output_distribution, + ) + return model + + +def _load(checkpoint_path): + if use_cuda: + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, + map_location=lambda storage, loc: storage) + return checkpoint + + +def load_checkpoint(path, model, optimizer, reset_optimizer): + global global_step + global global_epoch + global global_test_step + + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + model.load_state_dict(checkpoint["state_dict"]) + if not reset_optimizer: + optimizer_state = checkpoint["optimizer"] + if optimizer_state is not None: + print("Load optimizer state from {}".format(path)) + optimizer.load_state_dict(checkpoint["optimizer"]) + global_step = checkpoint["global_step"] + global_epoch = checkpoint["global_epoch"] + global_test_step = checkpoint.get("global_test_step", 0) + + return model + + +# https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3 +def restore_parts(path, model): + print("Restore part of the model from: {}".format(path)) + state = _load(path)["state_dict"] + model_dict = model.state_dict() + valid_state_dict = {k: 
v for k, v in state.items() if k in model_dict} + + try: + model_dict.update(valid_state_dict) + model.load_state_dict(model_dict) + except RuntimeError as e: + # there should be invalid size of weight(s), so load them per parameter + print(str(e)) + model_dict = model.state_dict() + for k, v in valid_state_dict.items(): + model_dict[k] = v + try: + model.load_state_dict(model_dict) + except RuntimeError as e: + print(str(e)) + warn("{}: may contain invalid size of weight. skipping...".format(k)) + + +def get_data_loaders(dump_root, speaker_id, test_shuffle=True): + data_loaders = {} + local_conditioning = hparams.cin_channels > 0 + + if hparams.max_time_steps is not None: + max_steps = ensure_divisible(hparams.max_time_steps, audio.get_hop_size(), True) + else: + max_steps = None + + for phase in ["train_no_dev", "dev"]: + train = phase == "train_no_dev" + X = FileSourceDataset( + RawAudioDataSource(join(dump_root, phase), speaker_id=speaker_id, + max_steps=max_steps, cin_pad=hparams.cin_pad, + hop_size=audio.get_hop_size())) + if local_conditioning: + Mel = FileSourceDataset( + MelSpecDataSource(join(dump_root, phase), speaker_id=speaker_id, + max_steps=max_steps, cin_pad=hparams.cin_pad, + hop_size=audio.get_hop_size())) + assert len(X) == len(Mel) + print("Local conditioning enabled. 
Shape of a sample: {}.".format( + Mel[0].shape)) + else: + Mel = None + print("[{}]: length of the dataset is {}".format(phase, len(X))) + + if train: + lengths = np.array(X.file_data_source.lengths) + # Prepare sampler + sampler = PartialyRandomizedSimilarTimeLengthSampler( + lengths, batch_size=hparams.batch_size) + shuffle = False + # make sure there are no sorting bugs; see https://github.com/r9y9/wavenet_vocoder/issues/130 + sampler_idx = np.asarray(sorted(list(map(lambda s: int(s), sampler)))) + assert (sampler_idx == np.arange(len(sampler_idx), dtype=np.int64)).all() + else: + sampler = None + shuffle = test_shuffle + + dataset = PyTorchDataset(X, Mel) + data_loader = data_utils.DataLoader( + dataset, batch_size=hparams.batch_size, drop_last=True, + num_workers=hparams.num_workers, sampler=sampler, shuffle=shuffle, + collate_fn=collate_fn, pin_memory=hparams.pin_memory) + + speaker_ids = {} + if X.file_data_source.multi_speaker: + for idx, (x, c, g) in enumerate(dataset): + if g is not None: + try: + speaker_ids[g] += 1 + except KeyError: + speaker_ids[g] = 1 + if len(speaker_ids) > 0: + print("Speaker stats:", speaker_ids) + + data_loaders[phase] = data_loader + + return data_loaders + + +if __name__ == "__main__": + args = docopt(__doc__) + print("Command line args:\n", args) + checkpoint_dir = args["--checkpoint-dir"] + checkpoint_path = args["--checkpoint"] + checkpoint_restore_parts = args["--restore-parts"] + speaker_id = args["--speaker-id"] + speaker_id = int(speaker_id) if speaker_id is not None else None + preset = args["--preset"] + + dump_root = args["--dump-root"] + if dump_root is None: + dump_root = join(dirname(__file__), "data", "ljspeech") + + log_event_path = args["--log-event-path"] + reset_optimizer = args["--reset-optimizer"] + + # Load preset if specified + if preset is not None: + with open(preset) as f: + hparams.parse_json(f.read()) + # Override hyperparameters + hparams.parse(args["--hparams"]) + assert hparams.name == 
"wavenet_vocoder" + print(hparams_debug_string()) + + fs = hparams.sample_rate + + os.makedirs(checkpoint_dir, exist_ok=True) + + output_json_path = join(checkpoint_dir, "hparams.json") + with open(output_json_path, "w") as f: + json.dump(hparams.values(), f, indent=2) + + # Dataloader setup + data_loaders = get_data_loaders(dump_root, speaker_id, test_shuffle=True) + + maybe_set_epochs_based_on_max_steps(hparams, len(data_loaders["train_no_dev"])) + + device = torch.device("cuda" if use_cuda else "cpu") + + # Model + model = build_model().to(device) + + receptive_field = model.receptive_field + print("Receptive field (samples / ms): {} / {}".format( + receptive_field, receptive_field / fs * 1000)) + + from torch import optim + Optimizer = getattr(optim, hparams.optimizer) + optimizer = Optimizer(model.parameters(), **hparams.optimizer_params) + + if checkpoint_restore_parts is not None: + restore_parts(checkpoint_restore_parts, model) + + # Load checkpoints + if checkpoint_path is not None: + load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer) + + # Setup summary writer for tensorboard + if log_event_path is None: + log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_") + print("TensorBoard event log path: {}".format(log_event_path)) + writer = SummaryWriter(log_dir=log_event_path) + + # Train! 
+ ema = None + try: + ema = train_loop(device, model, data_loaders, optimizer, writer, + checkpoint_dir=checkpoint_dir) + except KeyboardInterrupt: + print("Interrupted!") + pass + finally: + save_checkpoint( + device, model, optimizer, global_step, checkpoint_dir, global_epoch, ema) + + print("Finished") + + sys.exit(0) diff --git a/HuaWeiExperiment/wavenet/src/wavenet_vocoder/__init__.py b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/__init__.py new file mode 100644 index 00000000..e34d0b96 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" +from __future__ import with_statement, print_function, absolute_import +from .wavenet import WaveNet diff --git a/HuaWeiExperiment/wavenet/src/wavenet_vocoder/conv.py b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/conv.py new file mode 100644 index 00000000..85feed02 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/conv.py @@ -0,0 +1,182 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Extended Conv1D.""" + +import math +import numpy as np +from mindspore import nn, Tensor +from mindspore.ops import operations as P +import mindspore.common.dtype as mstype +from mindspore import context + +class Conv1d(nn.Conv1d): + """ + Extended nn.Conv1d adapted for incremental dilated convolution. + During training, the standard Conv1d forward is used; during evaluation, incremental_forward is called. + To improve inference speed, tensors are converted to NumPy arrays and the subsequent computation runs in NumPy. + These operations will be replaced with MindSpore ops in the future; currently, some operations are not supported + by MindSpore, and a mixed use of NumPy and MindSpore ops would be slow. + + """ + + def __init__(self, *args, **kwargs): + super(Conv1d, self).__init__(*args, **kwargs) + self.clear_buffer() + self._linearized_weight = None + self.transpose_op = P.Transpose() + self.reshape_op = P.Reshape() + self.squeeze_op = P.Squeeze(-2) + self.zeros = P.Zeros() + self.concat_op = P.Concat(axis=1) + self.matmul = P.MatMul(transpose_b=True) + self.bias_add = P.BiasAdd() + self.get_weight = None + self.get_bias = None + + def incremental_forward(self, inputs, is_numpy=True): + if is_numpy: + return self.incremental_forward_numpy(inputs) + return self.incremental_forward_pynative(inputs) + + def incremental_forward_pynative(self, inputs): + """ + Incremental forward. 
+ + Args: + inputs: B x T x C + + Returns: + ndarray + + """ + # input: (B, T, C) + if self.training: + raise RuntimeError('incremental_forward only supports eval mode') + + if self.get_weight is None: + self.get_weight = self._get_linearized_weight() + + if self.get_bias is None and self.bias is not None: + self.get_bias = self.bias + + # Note mindspore uses Conv2D to construct Conv1D + kw = self.kernel_size[1] + dilation = self.dilation[1] + + bsz = inputs.shape[0] # input: bsz x len x dim + if kw > 1: + if self.input_buffer is None: + init_buffer = self.zeros((bsz, kw + (kw - 1) * (dilation - 1), inputs.shape[2]), mstype.float32) + self.input_buffer = self.concat_op((init_buffer[:, 1:, :], inputs[:, 0:1, :])) + else: + # shift buffer + self.input_buffer = self.concat_op((self.input_buffer[:, 1:, :], inputs[:, 0:1, :])) + inputs = self.input_buffer + if dilation > 1: + if context.get_context("device_target") == "CPU": + inputs = self.transpose_op(inputs, (1, 0, 2)) + inputs = inputs[0::dilation, :, :] + inputs = self.transpose_op(inputs, (1, 0, 2)) + else: + inputs = inputs[:, 0::dilation, :] + + output = self.matmul(self.reshape_op(inputs, (bsz, -1)), self.get_weight) + if self.bias is not None: + output = self.bias_add(output, self.bias) + return self.reshape_op(output, (bsz, 1, -1)) + + def incremental_forward_numpy(self, inputs): + """ + Incremental forward. 
+ + Args: + inputs: B x T x C + + Returns: + ndarray + + """ + # input: (B, T, C) + if self.training: + raise RuntimeError('incremental_forward only supports eval mode') + + if self.get_weight is None: + weight = self._get_linearized_weight() + self.get_weight = weight.asnumpy() + + if self.get_bias is None and self.bias is not None: + bias = self.bias + self.get_bias = bias.asnumpy() + + # Note mindspore uses Conv2D to construct Conv1D + kw = self.kernel_size[1] + dilation = self.dilation[1] + + bsz = inputs.shape[0] # input: bsz x len x dim + if kw > 1: + if self.input_buffer is None: + self.input_buffer = np.zeros((bsz, kw + (kw - 1) * (dilation - 1), inputs.shape[2]), dtype=np.float32) + else: + # shift buffer + self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :] + # append next + self.input_buffer[:, -1, :] = inputs[:, -1, :] + inputs = self.input_buffer + if dilation > 1: + inputs = inputs[:, 0::dilation, :] + output = inputs.reshape(bsz, -1).dot(self.get_weight.T) + if self.bias is not None: + output = output + np.expand_dims(self.get_bias, 0) + return np.reshape(output, (bsz, 1, -1)) + + def clear_buffer(self): + self.input_buffer = None + + def _get_linearized_weight(self): + """ + get linearized weight + """ + weight = self.squeeze_op(self.weight) + if self._linearized_weight is None: + # Note mindspore uses Conv2D to construct Conv1D + kw = self.kernel_size[1] + if weight.shape == (self.out_channels, self.in_channels, kw): + weight = self.transpose_op(weight, (0, 2, 1)) + else: + weight = self.transpose_op(weight, (2, 0, 1)) + self._linearized_weight = self.reshape_op(weight, (self.out_channels, -1)) + return self._linearized_weight + + def _clear_linearized_weight(self, *args): + self._linearized_weight = None + + def _initialize_weights(self): + """ + weight initialization + """ + self.init_parameters_data() + std_mul = 4.0 + for _, m in self.cells_and_names(): + if isinstance(m, nn.Conv1d): + std = math.sqrt((std_mul * 0.1) / (m.kernel_size[1] 
* self.in_channels)) + m.weight.set_data(Tensor(np.random.normal(0, std, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_data( + Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_data( + Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) diff --git a/HuaWeiExperiment/wavenet/src/wavenet_vocoder/mixture.py b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/mixture.py new file mode 100644 index 00000000..a594fdd4 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/mixture.py @@ -0,0 +1,386 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Loss function for training and sample function for testing. 
+""" +import numpy as np +import mindspore as ms +from mindspore import Tensor +import mindspore.nn as nn +import mindspore.ops as P +from mindspore import context + + +class log_sum_exp(nn.Cell): + """Numerically stable log_sum_exp + """ + + def __init__(self): + super(log_sum_exp, self).__init__() + self.maxi = P.ReduceMax() + self.maxi_dim = P.ReduceMax(keep_dims=True) + self.log = P.Log() + self.sums = P.ReduceSum() + self.exp = P.Exp() + + def construct(self, x): + axis = len(x.shape) - 1 + m = self.maxi(x, axis) + m2 = self.maxi_dim(x, axis) + return m + self.log(self.sums(self.exp(x - m2), axis)) + + +class log_softmax(nn.Cell): + """ + replacement of P.LogSoftmax(-1) in CPU mode + only support x.shape == 2 or 3 + """ + + def __init__(self): + super(log_softmax, self).__init__() + self.maxi = P.ReduceMax() + self.log = P.Log() + self.sums = P.ReduceSum() + self.exp = P.Exp() + self.axis = -1 + self.concat = P.Concat(-1) + self.expanddims = P.ExpandDims() + + def construct(self, x): + """ + + Args: + x (Tensor): input + + Returns: + Tensor: log_softmax of input + + """ + c = self.maxi(x, self.axis) + logs, lsm = None, None + if len(x.shape) == 2: + for j in range(x.shape[-1]): + temp = self.expanddims(self.exp(x[:, j] - c), -1) + logs = temp if j == 0 else self.concat((logs, temp)) + sums = self.sums(logs, -1) + for i in range(x.shape[-1]): + temp = self.expanddims(x[:, i] - c - self.log(sums), -1) + lsm = temp if i == 0 else self.concat((lsm, temp)) + return lsm + if len(x.shape) == 3: + for j in range(x.shape[-1]): + temp = self.expanddims(self.exp(x[:, :, j] - c), -1) + logs = temp if j == 0 else self.concat((logs, temp)) + sums = self.sums(logs, -1) + for i in range(x.shape[-1]): + temp = self.expanddims(x[:, :, i] - c - self.log(sums), -1) + lsm = temp if i == 0 else self.concat((lsm, temp)) + return lsm + return None + + +class Stable_softplus(nn.Cell): + """Numerically stable softplus + """ + + def __init__(self): + super(Stable_softplus, 
self).__init__() + self.log_op = P.Log() + self.abs_op = P.Abs() + self.relu_op = P.ReLU() + self.exp_op = P.Exp() + + def construct(self, x): + return self.log_op(1 + self.exp_op(- self.abs_op(x))) + self.relu_op(x) + + +class discretized_mix_logistic_loss(nn.Cell): + """ + Discretized_mix_logistic_loss + + Args: + num_classes (int): Num_classes + log_scale_min (float): Log scale minimum value + + """ + + def __init__(self, num_classes=256, log_scale_min=-7.0, reduce=True): + super(discretized_mix_logistic_loss, self).__init__() + self.num_classes = num_classes + self.log_scale_min = log_scale_min + self.reduce = reduce + self.transpose_op = P.Transpose() + self.exp = P.Exp() + self.sigmoid = P.Sigmoid() + self.softplus = Stable_softplus() + self.log = P.Log() + self.cast = P.Cast() + self.expand_dims = P.ExpandDims() + self.tile = P.Tile() + self.maximum = P.Maximum() + self.sums = P.ReduceSum() + self.lse = log_sum_exp() + self.reshape = P.Reshape() + self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32)) + self.tensor_one = Tensor(1., ms.float32) + + if context.get_context("device_target") == "CPU": + self.logsoftmax = log_softmax() + else: + self.logsoftmax = P.LogSoftmax(-1) + + def construct(self, y_hat, y): + """ + + Args: + y_hat (Tensor): Predicted distribution + y (Tensor): Target + + Returns: + Tensor: Discretized_mix_logistic_loss + + """ + nr_mix = y_hat.shape[1] // 3 + + # (B x T x C) + y_hat = self.transpose_op(y_hat, (0, 2, 1)) + + # (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix:2 * nr_mix] + min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix)) + log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut) + + # B x T x 1 -> B x T x num_mixtures + y = self.tile(y, (1, 1, nr_mix)) + + centered_y = y - means + inv_stdv = self.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1. 
/ (self.num_classes - 1)) + cdf_plus = self.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1. / (self.num_classes - 1)) + cdf_min = self.sigmoid(min_in) + + log_cdf_plus = plus_in - self.softplus(plus_in) + + log_one_minus_cdf_min = -self.softplus(min_in) + + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + log_pdf_mid = mid_in - log_scales - 2. * self.softplus(mid_in) + + inner_inner_cond = self.cast(cdf_delta > 1e-5, ms.float32) + min_cut2 = 1e-12 * self.tile(self.tensor_one, cdf_delta.shape) + inner_inner_out = inner_inner_cond * \ + self.log(self.maximum(cdf_delta, min_cut2)) + \ + (1. - inner_inner_cond) * (log_pdf_mid - self.factor) + inner_cond = self.cast(y > 0.999, ms.float32) + inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out + cond = self.cast(y < -0.999, ms.float32) + log_probs = cond * log_cdf_plus + (1. - cond) * inner_out + + a, b, c = logit_probs.shape[0], logit_probs.shape[1], logit_probs.shape[2] + logit_probs = self.logsoftmax(self.reshape(logit_probs, (-1, c))) + logit_probs = self.reshape(logit_probs, (a, b, c)) + + log_probs = log_probs + logit_probs + if self.reduce: + return -self.sums(self.lse(log_probs)) + return self.expand_dims(-self.lse(log_probs), -1) + + +def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0): + """ + Sample from discretized mixture of logistic distributions + + Args: + y (ndarray): B x C x T + log_scale_min (float): Log scale minimum value + + Returns: + ndarray + """ + nr_mix = y.shape[1] // 3 + + # B x T x C + y = np.transpose(y, (0, 2, 1)) + logit_probs = y[:, :, :nr_mix] + + temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape) + temp = logit_probs - np.log(- np.log(temp)) + + argmax = np.argmax(temp, axis=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = np.eye(nr_mix)[argmax] + means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1) + log_scales = np.clip(np.sum( + y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1), 
a_min=log_scale_min, a_max=None) + + u = np.random.uniform(1e-5, 1.0 - 1e-5, means.shape) + x = means + np.exp(log_scales) * (np.log(u) - np.log(1. - u)) + x = np.clip(x, -1., 1.) + return x.astype(np.float32) + + +class mix_gaussian_loss(nn.Cell): + """ + Mix gaussian loss + """ + + def __init__(self, log_scale_min=-7.0, reduce=True): + super(mix_gaussian_loss, self).__init__() + self.log_scale_min = log_scale_min + self.reduce = reduce + self.transpose_op = P.Transpose() + self.maximum = P.Maximum() + self.tile = P.Tile() + self.exp = P.Exp() + self.expand_dims = P.ExpandDims() + self.sums = P.ReduceSum() + self.lse = log_sum_exp() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.const = P.ScalarToTensor() + self.log = P.Log() + self.tensor_one = Tensor(1., ms.float32) + + if context.get_context("device_target") == "CPU": + self.logsoftmax = log_softmax() + else: + self.logsoftmax = P.LogSoftmax(-1) + + def construct(self, y_hat, y): + """ + + Args: + y_hat (Tensor): Predicted probability + y (Tensor): Target + + Returns: + Tensor: Mix_gaussian_loss + + """ + C = y_hat.shape[1] + if C == 2: + nr_mix = 1 + else: + nr_mix = y_hat.shape[1] // 3 + + # (B x T x C) + y_hat = self.transpose_op(y_hat, (0, 2, 1)) + + if C == 2: + logit_probs = None + means = y_hat[:, :, 0:1] + min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], 1)) + log_scales = self.maximum(y_hat[:, :, 1:2], min_cut) + else: + # (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix:2 * nr_mix] + min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix)) + log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut) + + # B x T x 1 -> B x T x num_mixtures + y = self.tile(y, (1, 1, nr_mix)) + centered_y = y - means + + sd = self.exp(log_scales) + unnormalized_log_prob = -1. * (self.sq(centered_y - 0.)) / (2. * self.sq(sd)) + neg_normalization = -1. * self.log(self.const(2. 
* np.pi)) / 2. - self.log(sd) + log_probs = unnormalized_log_prob + neg_normalization + + if nr_mix > 1: + log_probs = log_probs + self.logsoftmax(logit_probs) + + if self.reduce: + if nr_mix == 1: + return -self.sums(log_probs) + return -self.sums(self.lse(log_probs)) + if nr_mix == 1: + return -log_probs + return self.expand_dims(-self.lse(log_probs), -1) + + +def sample_from_mix_gaussian(y, log_scale_min=-7.0): + """ + Sample_from_mix_gaussian + + Args: + y (ndarray): B x C x T + + Returns: + ndarray + + """ + C = y.shape[1] + if C == 2: + nr_mix = 1 + else: + nr_mix = y.shape[1] // 3 + + # B x T x C + y = np.transpose(y, (0, 2, 1)) + + if C == 2: + logit_probs = None + else: + logit_probs = y[:, :, :nr_mix] + + if nr_mix > 1: + temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape) + temp = logit_probs - np.log(- np.log(temp)) + argmax = np.argmax(temp, axis=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = np.eye(nr_mix)[argmax] + + means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1) + + log_scales = np.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1) + else: + if C == 2: + means, log_scales = y[:, :, 0], y[:, :, 1] + elif C == 3: + means, log_scales = y[:, :, 1], y[:, :, 2] + else: + assert False, "shouldn't happen" + + scales = np.exp(log_scales) + x = np.random.normal(loc=means, scale=scales) + x = np.clip(x, -1., 1.) 
+ return x.astype(np.float32) + + +# self-implemented onehotcategorical distribution +# https://zhuanlan.zhihu.com/p/59550457 +def sample_from_mix_onehotcategorical(x): + """ + Sample_from_mix_onehotcategorical + + Args: + x (ndarray): Predicted softmax probability + + Returns: + ndarray + + """ + pi = np.log(x) + u = np.random.uniform(0, 1, x.shape) + g = -np.log(-np.log(u)) + c = np.argmax(pi + g, axis=1) + return np.array(np.eye(256)[c], dtype=np.float32) diff --git a/HuaWeiExperiment/wavenet/src/wavenet_vocoder/modules.py b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/modules.py new file mode 100644 index 00000000..208049c7 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/modules.py @@ -0,0 +1,213 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Modules for WaveNet. 
+""" +from __future__ import with_statement, print_function, absolute_import +import math +import numpy as np +from wavenet_vocoder import conv +from mindspore import nn +from mindspore.ops import operations as P + + +def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): + m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs) + return m + + +def Conv1d1x1(in_channels, out_channels, has_bias=True): + return Conv1d(in_channels, out_channels, kernel_size=1, pad_mode='pad', padding=0, dilation=1, has_bias=has_bias) + + +def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + return m + + +def _conv1x1_forward(conv_, x, is_incremental, is_numpy=True): + """ + Conv1x1 forward + """ + if is_incremental: + x = conv_.incremental_forward(x, is_numpy=is_numpy) + else: + x = conv_(x) + return x + + +class ResidualConv1dGLU(nn.Cell): + """Residual dilated conv1d with gated activation units + + Args: + residual_channels (int): Residual input / output channels + gate_channels (int): Gated activation channels. + kernel_size (int): Kernel size + skip_out_channels (int): Skip connection channels. If None, it will set to the same as residual_channels. + cin_channels (int): Local conditioning channels. If given negative value, local conditioning is disabled. + gin_channels (int): Global conditioning channels. If given negative value, global conditioning is disabled. + dropout (float): Dropout rate. + padding (int): Padding for convolution layers. If None, padding value will be computed according to dilation + and kernel_size. + dilation (int): Dilation factor. 
+ + """ + + def __init__(self, residual_channels=None, gate_channels=None, kernel_size=None, skip_out_channels=None, bias=True, + dropout=1 - 0.95, dilation=1, cin_channels=-1, gin_channels=-1, padding=None, causal=True): + super(ResidualConv1dGLU, self).__init__() + self.dropout = dropout + self.dropout_op = nn.Dropout(p=self.dropout) + self.eval_split_op = P.Split(axis=-1, output_num=2) + self.train_split_op = P.Split(axis=1, output_num=2) + self.tanh = P.Tanh() + self.sigmoid = P.Sigmoid() + self.mul = P.Mul() + self.add = P.Add() + + if skip_out_channels is None: + skip_out_channels = residual_channels + if padding is None: + if causal: + padding = (kernel_size - 1) * dilation + else: + padding = (kernel_size - 1) // 2 * dilation + self.causal = causal + + self.conv = Conv1d(residual_channels, gate_channels, kernel_size, pad_mode='pad', + padding=padding, dilation=dilation, has_bias=bias) + + # local conditioning + if cin_channels > 0: + self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, has_bias=False) + else: + self.conv1x1c = None + + # global conditioning + if gin_channels > 0: + self.conv1x1g = Conv1d(gin_channels, gate_channels, has_bias=False, kernel_size=1, dilation=1) + else: + self.conv1x1g = None + + gate_out_channels = gate_channels // 2 + self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, has_bias=bias) + self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, has_bias=bias) + self.factor = math.sqrt(0.5) + + def construct(self, x, c=None, g=None): + """ + + Args: + x(Tensor): One-hot audio signal, the shape is B x C x T + c(Tensor): local conditional feature, the shape is B x cin_channels x T + g(Tensor): global conditional feature, not used currently + + Returns: + Tensor: Output tensor + + """ + + residual = x + x = self.dropout_op(x) + x = self.conv(x) + # remove future time steps + x = x[:, :, :residual.shape[-1]] if self.causal else x + split_op = self.train_split_op + + a, b = split_op(x) + + # local 
conditioning + if c is not None: + c = _conv1x1_forward(self.conv1x1c, c, is_incremental=False) + ca, cb = split_op(c) + a, b = a + ca, b + cb + + # global conditioning + if g is not None: + g = _conv1x1_forward(self.conv1x1g, g, is_incremental=False) + ga, gb = split_op(g) + a, b = a + ga, b + gb + + x = self.mul(self.tanh(a), self.sigmoid(b)) + + # For skip connection + s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental=False) + + # For residual connection + x = _conv1x1_forward(self.conv1x1_out, x, is_incremental=False) + + x = self.add(x, residual) * self.factor + return x, s + + def sigmoid_numpy(self, x): + return 1. / (1 + np.exp(-x)) + + def incremental_forward(self, x, c=None, g=None, is_numpy=True): + """ + Incremental forward. Used for the inference stage + + Args: + x (Tensor): One-hot audio signal, the shape is B x C x T + c (Tensor): local conditional feature, the shape is B x cin_channels x T + g (Tensor): global conditional feature, not used currently + + Returns: + ndarray + """ + residual = x + x = self.conv.incremental_forward(x, is_numpy=is_numpy) + if is_numpy: + a, b = np.split(x, indices_or_sections=2, axis=-1) + else: + a, b = self.eval_split_op(x) + + # local conditioning + if c is not None: + c = _conv1x1_forward(self.conv1x1c, c, is_incremental=True, is_numpy=is_numpy) + if is_numpy: + ca, cb = np.split(c, indices_or_sections=2, axis=-1) + else: + ca, cb = self.eval_split_op(c) + a, b = a + ca, b + cb + + # global conditioning + if g is not None: + g = _conv1x1_forward(self.conv1x1g, g, is_incremental=True, is_numpy=is_numpy) + if is_numpy: + ga, gb = np.split(g, indices_or_sections=2, axis=-1) + else: + ga, gb = self.eval_split_op(g) + a, b = a + ga, b + gb + + if is_numpy: + x = np.tanh(a) * self.sigmoid_numpy(b) + else: + x = self.mul(self.tanh(a), self.sigmoid(b)) + + # For skip connection + s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental=True, is_numpy=is_numpy) + + # For residual connection + x =
_conv1x1_forward(self.conv1x1_out, x, is_incremental=True, is_numpy=is_numpy) + + x = (x + residual) * self.factor + return x, s + + def clear_buffer(self): + """clear buffer""" + for c in [self.conv, self.conv1x1_out, self.conv1x1_skip, + self.conv1x1c, self.conv1x1g]: + if c is not None: + c.clear_buffer() diff --git a/HuaWeiExperiment/wavenet/src/wavenet_vocoder/tfcompat/__init__.py b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/tfcompat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/HuaWeiExperiment/wavenet/src/wavenet_vocoder/tfcompat/hparam.py b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/tfcompat/hparam.py new file mode 100644 index 00000000..c428176b --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/tfcompat/hparam.py @@ -0,0 +1,726 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Hyperparameter values.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import numbers +import re + +import six + +## from tensorflow.contrib.training.python.training import hparam_pb2 +## from tensorflow.python.framework import ops +## from tensorflow.python.util import compat +## from tensorflow.python.util import deprecation + +# Define the regular expression for parsing a single clause of the input +# (delimited by commas). A legal clause looks like: +# []? = +# where is either a single token or [] enclosed list of tokens. +# For example: "var[1] = a" or "x = [1,2,3]" +PARAM_RE = re.compile(r""" + (?P[a-zA-Z][\w\.]*) # variable name: "var" or "x" + (\[\s*(?P\d+)\s*\])? # (optional) index: "1" or None + \s*=\s* + ((?P[^,\[]*) # single value: "a" or None + | + \[(?P[^\]]*)\]) # list of values: None or "1,2,3" + ($|,\s*)""", re.VERBOSE) + + +def _parse_fail(name, var_type, value, values): + """Helper function for raising a value error for bad assignment.""" + raise ValueError( + 'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' % + (name, var_type.__name__, value, values)) + + +def _reuse_fail(name, values): + """Helper function for raising a value error for reuse of name.""" + raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name, + values)) + + +def _process_scalar_value(name, parse_fn, var_type, m_dict, values, + results_dictionary): + """Update results_dictionary with a scalar value. + + Used to update the results_dictionary to be returned by parse_values when + encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) + + Mutates results_dictionary. + + Args: + name: Name of variable in assignment ("s" or "arr"). + parse_fn: Function for parsing the actual value. + var_type: Type of named variable. 
+ m_dict: Dictionary constructed from regex parsing. + m_dict['val']: RHS value (scalar) + m_dict['index']: List index value (or None) + values: Full expression being parsed + results_dictionary: The dictionary being updated for return by the parsing + function. + + Raises: + ValueError: If the name has already been used. + """ + try: + parsed_value = parse_fn(m_dict['val']) + except ValueError: + _parse_fail(name, var_type, m_dict['val'], values) + + # If no index is provided + if not m_dict['index']: + if name in results_dictionary: + _reuse_fail(name, values) + results_dictionary[name] = parsed_value + else: + if name in results_dictionary: + # The name has already been used as a scalar, then it + # will be in this dictionary and map to a non-dictionary. + if not isinstance(results_dictionary.get(name), dict): + _reuse_fail(name, values) + else: + results_dictionary[name] = {} + + index = int(m_dict['index']) + # Make sure the index position hasn't already been assigned a value. + if index in results_dictionary[name]: + _reuse_fail('{}[{}]'.format(name, index), values) + results_dictionary[name][index] = parsed_value + + +def _process_list_value(name, parse_fn, var_type, m_dict, values, + results_dictionary): + """Update results_dictionary from a list of values. + + Used to update results_dictionary to be returned by parse_values when + encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) + + Mutates results_dictionary. + + Args: + name: Name of variable in assignment ("arr"). + parse_fn: Function for parsing individual values. + var_type: Type of named variable. + m_dict: Dictionary constructed from regex parsing. + m_dict['val']: RHS value (scalar) + values: Full expression being parsed + results_dictionary: The dictionary being updated for return by the parsing + function. + + Raises: + ValueError: If the name has an index or the values cannot be parsed. 
+ """ + if m_dict['index'] is not None: + raise ValueError('Assignment of a list to a list index.') + elements = filter(None, re.split('[ ,]', m_dict['vals'])) + # Make sure the name hasn't already been assigned a value + if name in results_dictionary: + raise _reuse_fail(name, values) + try: + results_dictionary[name] = [parse_fn(e) for e in elements] + except ValueError: + _parse_fail(name, var_type, m_dict['vals'], values) + + +def _cast_to_type_if_compatible(name, param_type, value): + """Cast hparam to the provided type, if compatible. + + Args: + name: Name of the hparam to be cast. + param_type: The type of the hparam. + value: The value to be cast, if compatible. + + Returns: + The result of casting `value` to `param_type`. + + Raises: + ValueError: If the type of `value` is not compatible with param_type. + * If `param_type` is a string type, but `value` is not. + * If `param_type` is a boolean, but `value` is not, or vice versa. + * If `param_type` is an integer type, but `value` is not. + * If `param_type` is a float type, but `value` is not a numeric type. + """ + fail_msg = ( + "Could not cast hparam '%s' of type '%s' from value %r" % + (name, param_type, value)) + + # Some callers use None, for which we can't do any casting/checking. :( + if issubclass(param_type, type(None)): + return value + + # Avoid converting a non-string type to a string. + if (issubclass(param_type, (six.string_types, six.binary_type)) and + not isinstance(value, (six.string_types, six.binary_type))): + raise ValueError(fail_msg) + + # Avoid converting a number or string type to a boolean or vice versa. + if issubclass(param_type, bool) != isinstance(value, bool): + raise ValueError(fail_msg) + + # Avoid converting float to an integer (the reverse is fine). + if (issubclass(param_type, numbers.Integral) and + not isinstance(value, numbers.Integral)): + raise ValueError(fail_msg) + + # Avoid converting a non-numeric type to a numeric type. 
+ if (issubclass(param_type, numbers.Number) and + not isinstance(value, numbers.Number)): + raise ValueError(fail_msg) + + return param_type(value) + + +def parse_values(values, type_map): + """Parses hyperparameter values from a string into a python map. + + `values` is a string containing comma-separated `name=value` pairs. + For each pair, the value of the hyperparameter named `name` is set to + `value`. + + If a hyperparameter name appears multiple times in `values`, a ValueError + is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2'). + + If a hyperparameter name appears in both an index assignment and a scalar + assignment, a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1'). + + The hyperparameter name may contain '.' symbols, which will result in an + attribute name that is only accessible through the getattr and setattr + functions. (And must first be explicitly added through add_hparam.) + + WARNING: Use of '.' in your variable names is allowed, but is not well + supported and not recommended. + + The `value` in `name=value` must follow the syntax according to the + type of the parameter: + + * Scalar integer: A Python-parsable integer value. E.g.: 1, + 100, -12. + * Scalar float: A Python-parsable floating point value. E.g.: 1.0, + -.54e89. + * Boolean: Either true or false. + * Scalar string: A non-empty sequence of characters, excluding commas, + spaces, and square brackets. E.g.: foo, bar_1. + * List: A comma-separated list of scalar values of the parameter type + enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low]. + + When index assignment is used, the corresponding type_map key should be the + list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not + "arr[1]"). + + Args: + values: String. Comma separated list of `name=value` pairs where + 'value' must follow the syntax described above. + type_map: A dictionary mapping hyperparameter names to types. Note every + parameter name in values must be a key in type_map.
The values must + conform to the types indicated, where a value V is said to conform to a + type T if either V has type T, or V is a list of elements of type T. + Hence, for a multidimensional parameter 'x' taking float values, + 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. + + Returns: + A python map mapping each name to either: + * A scalar value. + * A list of scalar values. + * A dictionary mapping index numbers to scalar values. + (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}") + + Raises: + ValueError: If there is a problem with input. + * If `values` cannot be parsed. + * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]'). + * If the same rvalue is assigned two different values (e.g. 'a=1,a=2', + 'a[1]=1,a[1]=2', or 'a=1,a=[1]') + """ + results_dictionary = {} + pos = 0 + while pos < len(values): + m = PARAM_RE.match(values, pos) + if not m: + raise ValueError('Malformed hyperparameter value: %s' % values[pos:]) + # Check that there is a comma between parameters and move past it. + pos = m.end() + # Parse the values. 
+ m_dict = m.groupdict() + name = m_dict['name'] + if name not in type_map: + raise ValueError('Unknown hyperparameter type for %s' % name) + type_ = type_map[name] + + # Set up correct parsing function (depending on whether type_ is a bool) + if type_ == bool: + + def parse_bool(value): + if value in ['true', 'True']: + return True + elif value in ['false', 'False']: + return False + else: + try: + return bool(int(value)) + except ValueError: + _parse_fail(name, type_, value, values) + + parse = parse_bool + else: + parse = type_ + + # If a single value is provided + if m_dict['val'] is not None: + _process_scalar_value(name, parse, type_, m_dict, values, + results_dictionary) + + # If the assigned value is a list: + elif m_dict['vals'] is not None: + _process_list_value(name, parse, type_, m_dict, values, + results_dictionary) + + else: # Not assigned a list or value + _parse_fail(name, type_, '', values) + + return results_dictionary + + +class HParams(object): + """Class to hold a set of hyperparameters as name-value pairs. + + A `HParams` object holds hyperparameters used to build and train a model, + such as the number of hidden units in a neural net layer or the learning rate + to use when training. + + You first create a `HParams` object by specifying the names and values of the + hyperparameters. + + To make them easily accessible the parameter names are added as direct + attributes of the class. A typical usage is as follows: + + ```python + # Create a HParams object specifying names and values of the model + # hyperparameters: + hparams = HParams(learning_rate=0.1, num_hidden_units=100) + + # The hyperparameters are available as attributes of the HParams object: + hparams.learning_rate ==> 0.1 + hparams.num_hidden_units ==> 100 + ``` + + Hyperparameters have a type, which is inferred from the type of their value + passed at construction time. The currently supported types are: integer, + float, boolean, string, and list of integer, float, boolean, or string.
+ + You can override hyperparameter values by calling the + [`parse()`](#HParams.parse) method, passing a string of comma separated + `name=value` pairs. This is intended to make it possible to override + any hyperparameter values from a single command-line flag to which + the user passes 'hyper-param=value' pairs. It avoids having to define + one flag for each hyperparameter. + + The syntax expected for each value depends on the type of the parameter. + See `parse()` for a description of the syntax. + + Example: + + ```python + # Define a command line flag to pass name=value pairs. + # For example using argparse: + import argparse + parser = argparse.ArgumentParser(description='Train my model.') + parser.add_argument('--hparams', type=str, + help='Comma separated list of "name=value" pairs.') + args = parser.parse_args() + ... + def my_program(): + # Create a HParams object specifying the names and values of the + # model hyperparameters: + hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100, + activations=['relu', 'tanh']) + + # Override hyperparameters values by parsing the command line + hparams.parse(args.hparams) + + # If the user passed `--hparams=learning_rate=0.3` on the command line + # then 'hparams' has the following attributes: + hparams.learning_rate ==> 0.3 + hparams.num_hidden_units ==> 100 + hparams.activations ==> ['relu', 'tanh'] + + # If the hyperparameters are in json format use parse_json: + hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}') + ``` + """ + + _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks. + + def __init__(self, hparam_def=None, model_structure=None, **kwargs): + """Create an instance of `HParams` from keyword arguments. + + The keyword arguments specify name-values pairs for the hyperparameters. + The parameter types are inferred from the type of the values passed. 
+ + The parameter names are added as attributes of `HParams` object, so they + can be accessed directly with the dot notation `hparams._name_`. + + Example: + + ```python + # Define 3 hyperparameters: 'learning_rate' is a float parameter, + # 'num_hidden_units' an integer parameter, and 'activation' a string + # parameter. + hparams = tf.HParams( + learning_rate=0.1, num_hidden_units=100, activation='relu') + + hparams.activation ==> 'relu' + ``` + + Note that a few names are reserved and cannot be used as hyperparameter + names. If you use one of the reserved name the constructor raises a + `ValueError`. + + Args: + hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef + protocol buffer. If provided, this object is initialized by + deserializing hparam_def. Otherwise **kwargs is used. + model_structure: An instance of ModelStructure, defining the feature + crosses to be used in the Trial. + **kwargs: Key-value pairs where the key is the hyperparameter name and + the value is the value for the parameter. + + Raises: + ValueError: If both `hparam_def` and initialization values are provided, + or if one of the arguments is invalid. + + """ + # Register the hyperparameters and their type in _hparam_types. + # This simplifies the implementation of parse(). + # _hparam_types maps the parameter name to a tuple (type, bool). + # The type value is the type of the parameter for scalar hyperparameters, + # or the type of the list elements for multidimensional hyperparameters. + # The bool value is True if the value is a list, False otherwise. 
+ self._hparam_types = {} + self._model_structure = model_structure + if hparam_def: +## self._init_from_proto(hparam_def) +## if kwargs: +## raise ValueError('hparam_def and initialization values are ' +## 'mutually exclusive') + raise ValueError('hparam_def has been disabled in this version') + else: + for name, value in six.iteritems(kwargs): + self.add_hparam(name, value) + +## def _init_from_proto(self, hparam_def): +## """Creates a new HParams from `HParamDef` protocol buffer. +## +## Args: +## hparam_def: `HParamDef` protocol buffer. +## """ +## assert isinstance(hparam_def, hparam_pb2.HParamDef) +## for name, value in hparam_def.hparam.items(): +## kind = value.WhichOneof('kind') +## if kind.endswith('_value'): +## # Single value. +## if kind.startswith('int64'): +## # Setting attribute value to be 'int' to ensure the type is compatible +## # with both Python2 and Python3. +## self.add_hparam(name, int(getattr(value, kind))) +## elif kind.startswith('bytes'): +## # Setting attribute value to be 'str' to ensure the type is compatible +## # with both Python2 and Python3. UTF-8 encoding is assumed. +## self.add_hparam(name, compat.as_str(getattr(value, kind))) +## else: +## self.add_hparam(name, getattr(value, kind)) +## else: +## # List of values. +## if kind.startswith('int64'): +## # Setting attribute value to be 'int' to ensure the type is compatible +## # with both Python2 and Python3. +## self.add_hparam(name, [int(v) for v in getattr(value, kind).value]) +## elif kind.startswith('bytes'): +## # Setting attribute value to be 'str' to ensure the type is compatible +## # with both Python2 and Python3. UTF-8 encoding is assumed. +## self.add_hparam( +## name, [compat.as_str(v) for v in getattr(value, kind).value]) +## else: +## self.add_hparam(name, [v for v in getattr(value, kind).value]) + + def add_hparam(self, name, value): + """Adds {name, value} pair to hyperparameters. + + Args: + name: Name of the hyperparameter. 
+ value: Value of the hyperparameter. Can be one of the following types: + int, float, string, int list, float list, or string list. + + Raises: + ValueError: if one of the arguments is invalid. + """ + # Keys in kwargs are unique, but 'name' could the name of a pre-existing + # attribute of this object. In that case we refuse to use it as a + # hyperparameter name. + if getattr(self, name, None) is not None: + raise ValueError('Hyperparameter name is reserved: %s' % name) + if isinstance(value, (list, tuple)): + if not value: + raise ValueError( + 'Multi-valued hyperparameters cannot be empty: %s' % name) + self._hparam_types[name] = (type(value[0]), True) + else: + self._hparam_types[name] = (type(value), False) + setattr(self, name, value) + + def set_hparam(self, name, value): + """Set the value of an existing hyperparameter. + + This function verifies that the type of the value matches the type of the + existing hyperparameter. + + Args: + name: Name of the hyperparameter. + value: New value of the hyperparameter. + + Raises: + ValueError: If there is a type mismatch. + """ + param_type, is_list = self._hparam_types[name] + if isinstance(value, list): + if not is_list: + raise ValueError( + 'Must not pass a list for single-valued parameter: %s' % name) + setattr(self, name, [ + _cast_to_type_if_compatible(name, param_type, v) for v in value]) + else: + if is_list: + raise ValueError( + 'Must pass a list for multi-valued parameter: %s.' % name) + setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) + + def del_hparam(self, name): + """Removes the hyperparameter with key 'name'. + + Args: + name: Name of the hyperparameter. + """ + if hasattr(self, name): + delattr(self, name) + del self._hparam_types[name] + + def parse(self, values): + """Override hyperparameter values, parsing new values from a string. + + See parse_values for more detail on the allowed format for values. + + Args: + values: String. 
Comma separated list of `name=value` pairs where + 'value' must follow the syntax described above. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values` cannot be parsed. + """ + type_map = dict() + for name, t in self._hparam_types.items(): + param_type, _ = t + type_map[name] = param_type + + values_map = parse_values(values, type_map) + return self.override_from_dict(values_map) + + def override_from_dict(self, values_dict): + """Override hyperparameter values, parsing new values from a dictionary. + + Args: + values_dict: Dictionary of name:value pairs. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values_dict` cannot be parsed. + """ + for name, value in values_dict.items(): + self.set_hparam(name, value) + return self + +## @deprecation.deprecated(None, 'Use `override_from_dict`.') + def set_from_map(self, values_map): + """DEPRECATED. Use override_from_dict.""" + return self.override_from_dict(values_dict=values_map) + + def set_model_structure(self, model_structure): + self._model_structure = model_structure + + def get_model_structure(self): + return self._model_structure + + def to_json(self, indent=None, separators=None, sort_keys=False): + """Serializes the hyperparameters into JSON. + + Args: + indent: If a non-negative integer, JSON array elements and object members + will be pretty-printed with that indent level. An indent level of 0, or + negative, will only insert newlines. `None` (the default) selects the + most compact representation. + separators: Optional `(item_separator, key_separator)` tuple. Default is + `(', ', ': ')`. + sort_keys: If `True`, the output dictionaries will be sorted by key. + + Returns: + A JSON string. + """ + return json.dumps( + self.values(), + indent=indent, + separators=separators, + sort_keys=sort_keys) + + def parse_json(self, values_json): + """Override hyperparameter values, parsing new values from a json object. 
+ + Args: + values_json: String containing a json object of name:value pairs. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values_json` cannot be parsed. + """ + values_map = json.loads(values_json) + return self.override_from_dict(values_map) + + def values(self): + """Return the hyperparameter values as a Python dictionary. + + Returns: + A dictionary with hyperparameter names as keys. The values are the + hyperparameter values. + """ + return {n: getattr(self, n) for n in self._hparam_types.keys()} + + def get(self, key, default=None): + """Returns the value of `key` if it exists, else `default`.""" + if key in self._hparam_types: + # Ensure that default is compatible with the parameter type. + if default is not None: + param_type, is_param_list = self._hparam_types[key] + type_str = 'list<%s>' % param_type if is_param_list else str(param_type) + fail_msg = ("Hparam '%s' of type '%s' is incompatible with " + 'default=%s' % (key, type_str, default)) + + is_default_list = isinstance(default, list) + if is_param_list != is_default_list: + raise ValueError(fail_msg) + + try: + if is_default_list: + for value in default: + _cast_to_type_if_compatible(key, param_type, value) + else: + _cast_to_type_if_compatible(key, param_type, default) + except ValueError as e: + raise ValueError('%s. %s' % (fail_msg, e)) + + return getattr(self, key) + + return default + + def __contains__(self, key): + return key in self._hparam_types + + def __str__(self): + return str(sorted(self.values().items())) + + def __repr__(self): + return '%s(%s)' % (type(self).__name__, self.__str__()) + + @staticmethod + def _get_kind_name(param_type, is_list): + """Returns the field name given parameter type and is_list. + + Args: + param_type: Data type of the hparam. + is_list: Whether this is a list. + + Returns: + A string representation of the field name. + + Raises: + ValueError: If parameter type is not recognized. 
+ """ + if issubclass(param_type, bool): + # This check must happen before issubclass(param_type, six.integer_types), + # since Python considers bool to be a subclass of int. + typename = 'bool' + elif issubclass(param_type, six.integer_types): + # Setting 'int' and 'long' types to be 'int64' to ensure the type is + # compatible with both Python2 and Python3. + typename = 'int64' + elif issubclass(param_type, (six.string_types, six.binary_type)): + # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is + # compatible with both Python2 and Python3. + typename = 'bytes' + elif issubclass(param_type, float): + typename = 'float' + else: + raise ValueError('Unsupported parameter type: %s' % str(param_type)) + + suffix = 'list' if is_list else 'value' + return '_'.join([typename, suffix]) + +## def to_proto(self, export_scope=None): # pylint: disable=unused-argument +## """Converts a `HParams` object to a `HParamDef` protocol buffer. +## +## Args: +## export_scope: Optional `string`. Name scope to remove. +## +## Returns: +## A `HParamDef` protocol buffer. +## """ +## hparam_proto = hparam_pb2.HParamDef() +## for name in self._hparam_types: +## # Parse the values. 
+## param_type, is_list = self._hparam_types.get(name, (None, None)) +## kind = HParams._get_kind_name(param_type, is_list) +## +## if is_list: +## if kind.startswith('bytes'): +## v_list = [compat.as_bytes(v) for v in getattr(self, name)] +## else: +## v_list = [v for v in getattr(self, name)] +## getattr(hparam_proto.hparam[name], kind).value.extend(v_list) +## else: +## v = getattr(self, name) +## if kind.startswith('bytes'): +## v = compat.as_bytes(getattr(self, name)) +## setattr(hparam_proto.hparam[name], kind, v) +## +## return hparam_proto + +## @staticmethod +## def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument +## return HParams(hparam_def=hparam_def) + + +## ops.register_proto_function( +## 'hparams', +## proto_type=hparam_pb2.HParamDef, +## to_proto=HParams.to_proto, +## from_proto=HParams.from_proto) diff --git a/HuaWeiExperiment/wavenet/src/wavenet_vocoder/tfcompat/readme.md b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/tfcompat/readme.md new file mode 100644 index 00000000..3d94e4c4 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/tfcompat/readme.md @@ -0,0 +1,8 @@ +Source: hparam.py copied from tensorflow v1.12.0. + +https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/training/python/training/hparam.py + +retrieved with the following command: +wget https://github.com/tensorflow/tensorflow/raw/v1.12.0/tensorflow/contrib/training/python/training/hparam.py + +Once all other TensorFlow dependencies of this file are removed, the class still serves its purpose. Functions that became unavailable through this removal are not used in this project.
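The copied hparam.py gives the training scripts a small `HParams` container whose values can be overridden from comma-separated strings (`parse`), dictionaries (`override_from_dict`), or JSON (`parse_json`), with each override cast to the type the hyperparameter was declared with. The override flow can be sketched with a minimal self-contained stand-in; `MiniHParams` below is hypothetical and much simpler than the real class (which, for example, special-cases bool because Python treats bool as a subclass of int):

```python
import json


class MiniHParams:
    """Tiny hypothetical stand-in mimicking the override flow of the copied hparam.py."""

    def __init__(self, **kwargs):
        # Each keyword defines one hyperparameter; remember its declared type.
        self._types = {}
        for name, value in kwargs.items():
            self._types[name] = type(value)
            setattr(self, name, value)

    def override_from_dict(self, values_dict):
        # Unknown names are rejected, known values are cast to the declared type,
        # mirroring the ValueError behavior of the real override_from_dict.
        for name, value in values_dict.items():
            if name not in self._types:
                raise ValueError("Unknown hyperparameter: %s" % name)
            setattr(self, name, self._types[name](value))
        return self

    def parse_json(self, values_json):
        # The real parse_json also just delegates to override_from_dict.
        return self.override_from_dict(json.loads(values_json))

    def values(self):
        # Same shape as HParams.values(): a plain name -> value dict.
        return {n: getattr(self, n) for n in self._types}


hp = MiniHParams(learning_rate=1e-3, layers=20, input_type="mulaw-quantize")
hp.parse_json('{"learning_rate": 1e-4}')   # JSON override, as train.py does with --preset
hp.override_from_dict({"layers": 24})      # dict override
print(hp.values())
```

This is the same pattern `train.py` relies on: a preset JSON file is read and fed through `parse_json`, after which every hyperparameter is available as an attribute (`hparams.out_channels`, `hparams.layers`, ...).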
diff --git a/HuaWeiExperiment/wavenet/src/wavenet_vocoder/upsample.py b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/upsample.py new file mode 100644 index 00000000..32b4ba15 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/upsample.py @@ -0,0 +1,111 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Upsampling. +""" +from __future__ import with_statement, print_function, absolute_import +import numpy as np +from mindspore import nn +from mindspore.ops import operations as P + + +class Resize(nn.Cell): + """ + Resize input Tensor + """ + + def __init__(self, x_scale, y_scale, mode="nearest"): + super(Resize, self).__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def construct(self, x): + _, _, h, w = x.shape + interpolate_op = P.ResizeNearestNeighbor((self.y_scale * h, self.x_scale * w)) + return interpolate_op(x) + + +def _get_activation(upsample_activation): + """get activation""" + nonlinear = getattr(nn, upsample_activation) + return nonlinear + + +class UpsampleNetwork(nn.Cell): + """UpsampleNetwork""" + def __init__(self, upsample_scales, mode="nearest", + freq_axis_kernel_size=1, cin_pad=0, cin_channels=80): + super(UpsampleNetwork, self).__init__() + self.expand_op = P.ExpandDims() + self.squeeze_op = P.Squeeze(1) + up_layers = [] + total_scale = np.prod(upsample_scales) + 
self.indent = cin_pad * total_scale + for scale in upsample_scales: + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + k_size = (freq_axis_kernel_size, scale * 2 + 1) + padding = (freq_axis_padding, freq_axis_padding, scale, scale) + stretch = Resize(scale, 1, mode) + conv = nn.Conv2d(1, 1, kernel_size=k_size, has_bias=False, pad_mode='pad', padding=padding) + up_layers.append(stretch) + up_layers.append(conv) + self.up_layers = nn.CellList(up_layers) + + def construct(self, c): + """ + + Args: + c (Tensor): Local conditioning feature + + Returns: + Tensor: Upsampled feature + + """ + # B x 1 x C x T + c = self.expand_op(c, 1) + for f in self.up_layers: + c = f(c) + # B x C x T + c = self.squeeze_op(c) + + return c + + +class ConvInUpsampleNetwork(nn.Cell): + """Upsample network with an input convolution. + + Args: + upsample_scales (list): List of upsample scales. + mode (str): Resize mode, default is "nearest" (nearest neighbor). + freq_axis_kernel_size (int): Freq-axis kernel size for the convolution layers after resize. + cin_pad (int): Padding of the local conditioning features. + cin_channels (int): Local conditioning channels.
+ + """ + + def __init__(self, upsample_scales, mode="nearest", + freq_axis_kernel_size=1, cin_pad=0, + cin_channels=80): + super(ConvInUpsampleNetwork, self).__init__() + ks = 2 * cin_pad + 1 + self.conv_in = nn.Conv1d(cin_channels, cin_channels, kernel_size=ks, has_bias=False, pad_mode='pad', padding=0) + self.upsample = UpsampleNetwork(upsample_scales, mode, freq_axis_kernel_size, cin_pad=0, + cin_channels=cin_channels) + + def construct(self, c): + c = self.conv_in(c) + c_up = self.upsample(c) + return c_up diff --git a/HuaWeiExperiment/wavenet/src/wavenet_vocoder/util.py b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/util.py new file mode 100644 index 00000000..4fea1d98 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/util.py @@ -0,0 +1,25 @@ +# coding: utf-8 +from __future__ import with_statement, print_function, absolute_import + + +def _assert_valid_input_type(s): + assert s == "mulaw-quantize" or s == "mulaw" or s == "raw" + + +def is_mulaw_quantize(s): + _assert_valid_input_type(s) + return s == "mulaw-quantize" + + +def is_mulaw(s): + _assert_valid_input_type(s) + return s == "mulaw" + + +def is_raw(s): + _assert_valid_input_type(s) + return s == "raw" + + +def is_scalar_input(s): + return is_raw(s) or is_mulaw(s) diff --git a/HuaWeiExperiment/wavenet/src/wavenet_vocoder/wavenet.py b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/wavenet.py new file mode 100644 index 00000000..50ca5b03 --- /dev/null +++ b/HuaWeiExperiment/wavenet/src/wavenet_vocoder/wavenet.py @@ -0,0 +1,335 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +WaveNet construction. +""" +from __future__ import with_statement, print_function, absolute_import + +import math +import numpy as np + +from mindspore import nn, Tensor +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from wavenet_vocoder import upsample +from .modules import Embedding +from .modules import Conv1d1x1 +from .modules import ResidualConv1dGLU +from .mixture import sample_from_discretized_mix_logistic +from .mixture import sample_from_mix_gaussian +from .mixture import sample_from_mix_onehotcategorical + + +class WaveNet(nn.Cell): + """ + WaveNet model definition. Only local conditioning is supported. + + Args: + out_channels (int): Output channels. If input_type is a mu-law quantized one-hot vector, it should equal the + quantize channels. Otherwise, it equals num_mixtures x 3. Default: 256. + layers (int): Number of ResidualConv1dGLU layers. + stacks (int): Number of dilation cycles. + residual_channels (int): Residual input / output channels. + gate_channels (int): Gated activation channels. + skip_out_channels (int): Skip connection channels. + kernel_size (int): Kernel size. + dropout (float): Dropout rate. + cin_channels (int): Local conditioning channels. If a negative value is given, local conditioning is disabled. + gin_channels (int): Global conditioning channels. If a negative value is given, global conditioning is disabled. + n_speakers (int): Number of speakers. This is used when global conditioning is enabled.
+ upsample_conditional_features (bool): Whether to upsample local conditioning features by nearest-neighbor + resize and convolution. + scalar_input (bool): If True, a scalar input ([-1, 1]) is expected; otherwise, a quantized one-hot vector + is expected. + use_speaker_embedding (bool): Whether to use a speaker embedding. + + """ + + def __init__(self, out_channels=256, layers=20, stacks=2, + residual_channels=512, + gate_channels=512, + skip_out_channels=512, + kernel_size=3, dropout=1 - 0.95, + cin_channels=-1, gin_channels=-1, n_speakers=None, + upsample_conditional_features=False, + upsample_net="ConvInUpsampleNetwork", + upsample_params=None, + scalar_input=False, + use_speaker_embedding=False, + output_distribution="Logistic", + cin_pad=0, + ): + super(WaveNet, self).__init__() + self.transpose_op = P.Transpose() + self.softmax = P.Softmax(axis=1) + self.reshape_op = P.Reshape() + self.zeros_op = P.Zeros() + self.ones_op = P.Ones() + self.squeeze_op = P.Squeeze() + self.expandim_op = P.ExpandDims() + self.tile_op = P.Tile() + self.scalar_input = scalar_input + self.out_channels = out_channels + self.cin_channels = cin_channels + self.output_distribution = output_distribution + self.fack_data = P.Zeros() + assert layers % stacks == 0 + layers_per_stack = layers // stacks # e.g. 24 layers / 4 stacks = 6 layers per stack + if scalar_input: + self.first_conv = Conv1d1x1(1, residual_channels) + else: + self.first_conv = Conv1d1x1(out_channels, residual_channels) + + conv_layers = [] + for layer in range(layers): + dilation = 2 ** (layer % layers_per_stack) # dilations cycle through 1, 2, 4, ..., 2 ** (layers_per_stack - 1) + conv = ResidualConv1dGLU( + residual_channels, gate_channels, + kernel_size=kernel_size, + skip_out_channels=skip_out_channels, + bias=True, + dropout=dropout, + dilation=dilation, + cin_channels=cin_channels, + gin_channels=gin_channels) + conv_layers.append(conv) + self.conv_layers = nn.CellList(conv_layers) + self.last_conv_layers = nn.CellList([ + nn.ReLU(), + Conv1d1x1(skip_out_channels,
skip_out_channels), + nn.ReLU(), + Conv1d1x1(skip_out_channels, out_channels)]) + + if gin_channels > 0 and use_speaker_embedding: + assert n_speakers is not None + self.embed_speakers = Embedding( + n_speakers, gin_channels, padding_idx=None, std=0.1) + else: + self.embed_speakers = None + + if upsample_conditional_features: + self.upsample_net = getattr(upsample, upsample_net)(**upsample_params) + else: + self.upsample_net = None + + self.factor = math.sqrt(1.0 / len(self.conv_layers)) # sqrt(1 / layers) + + def _expand_global_features(self, batch_size, time_step, g_fp, is_expand=True): + """Expand global conditioning features to all time steps + + Args: + batch_size (int): Batch size. + time_step (int): Time length. + g_fp (Tensor): Global features, (B x C) or (B x C x 1). + is_expand (bool): If True, return the expanded features as B x C x T; otherwise as B x T x C. + + Returns: + Tensor: B x C x T or B x T x C or None + """ + if g_fp is None: + return None + if len(g_fp.shape) == 2: + g_fp = self.expandim_op(g_fp, -1) + else: + g_fp = g_fp + + if is_expand: + expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step)) + return expand_fp + expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step)) + expand_fp = self.transpose_op(expand_fp, (0, 2, 1)) + return expand_fp + + def construct(self, x, cond=None, g=None, softmax=False): + """ + + Args: + x (Tensor): One-hot encoded audio signal + cond (Tensor): Local conditioning feature + g (Tensor): Global conditioning feature + softmax (bool): Whether to apply softmax + + Returns: + Tensor: Net output + + """ + + g = None + B, _, T = x.shape + if g is not None: + if self.embed_speakers is not None: + g = self.embed_speakers(self.reshape_op(g, (B, -1))) + g = self.transpose_op(g, (0, 2, 1)) + g_bct = self._expand_global_features(B, T, g, is_expand=True) # None + + if cond is not None and self.upsample_net is not None: + cond = self.upsample_net(cond) # [B, 128, 10240] + + x = self.first_conv(x) + skips = 0 + for f in self.conv_layers: + x, hidden = f(x,
cond, g_bct) # x=[B, 128, 10240], hidden=[B, 128, 10240] + skips += hidden + skips *= self.factor + + x = skips # x=[B, 128, 10240] + for f in self.last_conv_layers: + x = f(x) # x=[B, 2, 10240] + x = self.softmax(x) if softmax else x + + return x + + def relu_numpy(self, inX): + """numpy relu function""" + return np.maximum(0, inX) + + def softmax_numpy(self, x): + """numpy softmax function""" + x -= np.max(x, axis=1, keepdims=True) + return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True) + + def incremental_forward(self, initial_input=None, c_=None, g=None, + T=100, test_inputs=None, + tqdm=lambda x: x, softmax=True, quantize=True, + log_scale_min=-50.0, is_numpy=True): + """ + Incremental forward. The current output depends on the previous outputs. + + Args: + initial_input (Tensor): Initial input, the shape is B x C x 1 + c_ (Tensor): Local conditioning feature, the shape is B x C x T + g (Tensor): Global conditioning feature, the shape is B x C or B x C x 1 + T (int): Number of decoding time steps. + test_inputs: Teacher-forcing inputs (for debugging) + tqdm (lambda): Progress-bar wrapper, e.g. tqdm + softmax (bool): Whether to apply softmax + quantize (bool): Whether to quantize the softmax output of the last step when decoding the current step + log_scale_min (float): Log scale minimum value + + Returns: + Tensor: Predicted one-hot encoded samples or scalar vector depending on loss type + + """ + + self.clear_buffer() + B = 1 + + if test_inputs is not None: + if self.scalar_input: + if test_inputs.shape[1] == 1: + test_inputs = self.transpose_op(test_inputs, (0, 2, 1)) + else: + if test_inputs.shape[1] == self.out_channels: + test_inputs = self.transpose_op(test_inputs, (0, 2, 1)) + + B = test_inputs.shape[0] + if T is None: + T = test_inputs.shape[1] + else: + T = max(T, test_inputs.shape[1]) + T = int(T) + + # Global conditioning + if g is not None: + if self.embed_speakers is not None: + g = self.embed_speakers(self.reshape_op(g, (B, -1))) + g = self.transpose_op(g, (0, 2, 1)) + assert g.dim() == 3 + g_btc =
self._expand_global_features(B, T, g, is_expand=False) + + # Local conditioning + if c_ is not None: + B = c_.shape[0] + if self.upsample_net is not None: + c_ = self.upsample_net(c_) + assert c_.shape[-1] == T + if c_.shape[-1] == T: + c_ = self.transpose_op(c_, (0, 2, 1)) + + outputs = [] + if initial_input is None: + if self.scalar_input: + initial_input = self.zeros_op((B, 1, 1), mstype.float32) + else: + initial_input = np.zeros((B, 1, self.out_channels), np.float32) + initial_input[:, :, 127] = 1 + initial_input = Tensor(initial_input) + else: + if initial_input.shape[1] == self.out_channels: + initial_input = self.transpose_op(initial_input, (0, 2, 1)) + + current_input = initial_input.asnumpy() + + for t in tqdm(range(T)): + if test_inputs is not None and t < test_inputs.shape[1]: + current_input = self.expandim_op(test_inputs[:, t, :], 1) + else: + if t > 0: + current_input = outputs[-1] + + # Conditioning features for single time step + ct = None if c_ is None else self.expandim_op(c_[:, t, :], 1) + gt = None if g is None else self.expandim_op(g_btc[:, t, :], 1) + + x = current_input + ct = ct.asnumpy() + x = self.first_conv.incremental_forward(x) + + skips = 0 + for f in self.conv_layers: + x, h = f.incremental_forward(x, ct, gt) + skips += h + skips *= self.factor + x = skips + + for f in self.last_conv_layers: + try: + x = f.incremental_forward(x) + except AttributeError: + x = self.relu_numpy(x) + + # Generate next input by sampling + if self.scalar_input: + if self.output_distribution == "Logistic": + x = sample_from_discretized_mix_logistic(x.reshape((B, -1, 1)), log_scale_min=log_scale_min) + + elif self.output_distribution == "Normal": + x = sample_from_mix_gaussian(x.reshape((B, -1, 1)), log_scale_min=log_scale_min) + else: + assert False + else: + x = self.softmax_numpy(np.reshape(x, (B, -1))) if softmax else np.reshape(x, (B, -1)) + if quantize: + x = sample_from_mix_onehotcategorical(x) + + outputs += [x] + # T x B x C + outputs = 
np.stack(outputs, 0) + # B x C x T + outputs = np.transpose(outputs, (1, 2, 0)) + self.clear_buffer() + return outputs + + def clear_buffer(self): + """clear buffer""" + self.first_conv.clear_buffer() + for f in self.conv_layers: + f.clear_buffer() + for f in self.last_conv_layers: + try: + f.clear_buffer() + except AttributeError: + pass diff --git a/HuaWeiExperiment/wavenet/train.log b/HuaWeiExperiment/wavenet/train.log new file mode 100644 index 00000000..9837edf7 --- /dev/null +++ b/HuaWeiExperiment/wavenet/train.log @@ -0,0 +1 @@ +bash: CUDA_VISIBLE_DEVICES: command not found diff --git a/HuaWeiExperiment/wavenet/train.py b/HuaWeiExperiment/wavenet/train.py new file mode 100644 index 00000000..cda758e3 --- /dev/null +++ b/HuaWeiExperiment/wavenet/train.py @@ -0,0 +1,188 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Train WaveNet.
+""" +import os +from os.path import join +import json +import argparse +from warnings import warn +from hparams import hparams, hparams_debug_string + +import mindspore +from mindspore import context, Tensor +from mindspore.context import ParallelMode +from mindspore.communication.management import init, get_rank +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.nn.optim import Adam # Momentum, Adagrad, SGD +from mindspore.nn import TrainOneStepCell +from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell +from mindspore.train import Model +from mindspore.train.callback import SummaryCollector +from src.lr_generator import get_lr, get_lrv2 +from src.dataset import get_data_loaders +from src.loss import NetWithLossClass, WaveNetTrainOneStepWithLossScaleCell +from src.callback import Monitor +from wavenet_vocoder import WaveNet +from wavenet_vocoder.util import is_mulaw_quantize, is_scalar_input + +mindspore.common.set_seed(1024) + +parser = argparse.ArgumentParser(description='TTS training') +parser.add_argument('--data_path', type=str, required=True, default='', + help='Directory contains preprocessed features.') +parser.add_argument('--preset', type=str, required=True, default='', help='Path of preset parameters (json).') +parser.add_argument('--checkpoint_dir', type=str, default='./checkpoints_test', + help='Directory where to save model checkpoints [default: checkpoints].') +parser.add_argument('--checkpoint', type=str, default='', help='Restore model from checkpoint path if given.') +parser.add_argument('--speaker_id', type=str, default='', + help=' Use specific speaker of data in case for multi-speaker datasets.') +parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU', 'CPU'), + help='run platform, support Ascend, GPU and CPU. 
Default: GPU') +parser.add_argument('--mode_name', type=str, default='GRAPH', choices=('GRAPH', 'PYNATIVE'), help='Choose Mode') +parser.add_argument('--is_distributed', action="store_true", default=False, help='Distributed training') +args = parser.parse_args() + +if __name__ == '__main__': + + # init context + target = args.platform + if args.mode_name == 'GRAPH': + context.set_context(mode=context.GRAPH_MODE, device_target=target, save_graphs=False) + else: + context.set_context(mode=context.PYNATIVE_MODE, device_target=target, save_graphs=False) + + if target == 'Ascend': + rank_id = int(os.getenv('RANK_ID')) + group_size = int(os.getenv('RANK_SIZE')) + device_id = int(os.getenv("DEVICE_ID")) + context.set_context(device_id=device_id) + else: + rank_id = 0 + group_size = 1 + device_id = 0 + + if args.is_distributed: + context.reset_auto_parallel_context() + # Ascend + if target == 'Ascend': + context.set_auto_parallel_context(device_num=group_size, parallel_mode=ParallelMode.DATA_PARALLEL, + gradients_mean=True, parameter_broadcast=True) + # GPU + else: + context.set_auto_parallel_context(device_num=group_size, parallel_mode=ParallelMode.DATA_PARALLEL, + gradients_mean=True) + init() + speaker_id = int(args.speaker_id) if args.speaker_id != '' else None + if args.preset is not None: + with open(args.preset) as f: + hparams.parse_json(f.read()) + + assert hparams.name == "wavenet_vocoder" + print(hparams_debug_string()) + fs = hparams.sample_rate + os.makedirs(args.checkpoint_dir, exist_ok=True) + + output_json_path = join(args.checkpoint_dir, "hparams.json") + with open(output_json_path, "w") as f: + json.dump(hparams.values(), f, indent=2) + + data_loaders = get_data_loaders(args.data_path, args.speaker_id, hparams=hparams, rank_id=rank_id, + group_size=group_size) + step_size_per_epoch = data_loaders.get_dataset_size() + + if is_mulaw_quantize(hparams.input_type): + if hparams.out_channels != hparams.quantize_channels: + raise RuntimeError( + "out_channels 
must equal quantize_channels if input_type is 'mulaw-quantize'") + if hparams.upsample_conditional_features and hparams.cin_channels < 0: + s = "Upsample conv layers were specified while local conditioning is disabled. " + s += "Notice that the upsample conv layers will never be used." + warn(s) + + upsample_params = hparams.upsample_params + upsample_params["cin_channels"] = hparams.cin_channels + upsample_params["cin_pad"] = hparams.cin_pad + model = WaveNet( + out_channels=hparams.out_channels, + layers=hparams.layers, + stacks=hparams.stacks, + residual_channels=hparams.residual_channels, + gate_channels=hparams.gate_channels, + skip_out_channels=hparams.skip_out_channels, + cin_channels=hparams.cin_channels, + gin_channels=hparams.gin_channels, + n_speakers=hparams.n_speakers, + dropout=hparams.dropout, + kernel_size=hparams.kernel_size, + cin_pad=hparams.cin_pad, + upsample_conditional_features=hparams.upsample_conditional_features, + upsample_params=upsample_params, + scalar_input=is_scalar_input(hparams.input_type), + output_distribution=hparams.output_distribution, + ) + + loss_net = NetWithLossClass(model, hparams) + if target == 'Ascend': + lr = get_lrv2(hparams.optimizer_params["lr"], hparams.nepochs, step_size_per_epoch, step_size_per_epoch * 30) + lr = Tensor(lr) + + resume_epoch = None + if args.checkpoint != '': + resume_epoch = 600 # pre-trained training epochs + resume_step = int(args.checkpoint.split('_')[-1].split('.')[0]) + lr = lr[resume_epoch * step_size_per_epoch:] + param_dict = load_checkpoint(args.checkpoint) + load_param_into_net(model, param_dict) + print('Successfully loaded the pre-trained model') + loss_scale = mindspore.train.loss_scale_manager.DynamicLossScaleManager(2 ** 16, 2, 1000) + scale_update_cell = DynamicLossScaleUpdateCell(loss_scale_value=2 ** 12, + scale_factor=2, + scale_window=1024) + weights = model.trainable_params() + optimizer = Adam(weights, learning_rate=lr) + train_net =
WaveNetTrainOneStepWithLossScaleCell(loss_net, optimizer, scale_update_cell) + else: + lr = get_lr(hparams.optimizer_params["lr"], hparams.nepochs, step_size_per_epoch) + lr = Tensor(lr) + if args.checkpoint != '': + param_dict = load_checkpoint(args.checkpoint) + load_param_into_net(model, param_dict) + print('Successfully loaded the pre-trained model') + weights = model.trainable_params() + optimizer = Adam(weights, learning_rate=lr, loss_scale=1024.) + train_net = TrainOneStepCell(loss_net, optimizer) + + if target == 'Ascend': + summary_collector = SummaryCollector(summary_dir='summary_dir/device_{}'.format(device_id), collect_freq=1) + model = Model(train_net) + lr_cb = Monitor(lr) + callback_list = [lr_cb] + if args.is_distributed: + ckpt_path = os.path.join(args.checkpoint_dir, 'ckpt_' + str(get_rank()) + '/') + + else: + ckpt_path = args.checkpoint_dir + config_ck = CheckpointConfig(save_checkpoint_steps=step_size_per_epoch, keep_checkpoint_max=hparams.nepochs) + ckpt_cb = ModelCheckpoint(prefix='wavenet', directory=ckpt_path, config=config_ck) + callback_list.append(ckpt_cb) + if target == 'Ascend': + callback_list.append(summary_collector) + if target == 'Ascend' and resume_epoch is not None: + model.train(hparams.nepochs - resume_epoch, data_loaders, callbacks=callback_list, dataset_sink_mode=False) + else: + model.train(hparams.nepochs, data_loaders, callbacks=callback_list, dataset_sink_mode=False) diff --git a/HuaWeiExperiment/wavenet/train_pytorch.py b/HuaWeiExperiment/wavenet/train_pytorch.py new file mode 100644 index 00000000..8f8e56e4 --- /dev/null +++ b/HuaWeiExperiment/wavenet/train_pytorch.py @@ -0,0 +1,1117 @@ +"""Training script for WaveNet vocoder + +usage: train.py [options] + +options: + --dump-root= Directory containing preprocessed features. + --checkpoint-dir= Directory where to save model checkpoints [default: checkpoints]. + --hparams= Hyper parameters [default: ]. + --preset= Path of preset parameters (json).
+ --checkpoint= Restore model from checkpoint path if given. + --restore-parts= Restore part of the model. + --log-event-path= Log event path. + --reset-optimizer Reset optimizer. + --speaker-id= Use specific speaker of data in case for multi-speaker datasets. + -h, --help Show this help message and exit +""" +from docopt import docopt + +import sys + +import os +from os.path import dirname, join, expanduser, exists +from tqdm import tqdm +from datetime import datetime +import random +import json +from glob import glob + +import numpy as np + +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + + +import torch +from torch import nn +from torch.nn import functional as F +from torch import optim +import torch.backends.cudnn as cudnn +from torch.utils import data as data_utils +from torch.utils.data.sampler import Sampler + +import torch.optim.lr_scheduler as lrschedule + + +from nnmnkwii import preprocessing as P +from nnmnkwii.datasets import FileSourceDataset, FileDataSource + +import librosa.display + +from tensorboardX import SummaryWriter +from matplotlib import cm +from warnings import warn + +from wavenet_vocoder import WaveNet +from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_raw, is_scalar_input +from wavenet_vocoder.mixture import discretized_mix_logistic_loss +from wavenet_vocoder.mixture import sample_from_discretized_mix_logistic +from wavenet_vocoder.mixture import mix_gaussian_loss +from wavenet_vocoder.mixture import sample_from_mix_gaussian + +import audio +from hparams import hparams, hparams_debug_string + + +global_step = 0 +global_test_step = 0 +global_epoch = 0 +use_cuda = torch.cuda.is_available() +if use_cuda: + cudnn.benchmark = True + + +def sanity_check(model, c, g): + if model.has_speaker_embedding(): + if g is None: + raise RuntimeError( + "WaveNet expects speaker embedding, but speaker-id is not provided") + else: + if g is not None: + raise RuntimeError( + "WaveNet expects no speaker embedding, 
but speaker-id is provided") + + if model.local_conditioning_enabled(): + if c is None: + raise RuntimeError("WaveNet expects conditional features, but not given") + else: + if c is not None: + raise RuntimeError("WaveNet expects no conditional features, but given") + + +def maybe_set_epochs_based_on_max_steps(hp, steps_per_epoch): + nepochs = hp.nepochs + max_train_steps = hp.max_train_steps + if max_train_steps is not None: + epochs = int(np.ceil(max_train_steps / steps_per_epoch)) + hp.nepochs = epochs + print("info; Number of epochs is set based on max_train_steps: {}".format(epochs)) + + +def _pad(seq, max_len, constant_values=0): + return np.pad(seq, (0, max_len - len(seq)), + mode='constant', constant_values=constant_values) + + +def _pad_2d(x, max_len, b_pad=0, constant_values=0): + x = np.pad(x, [(b_pad, max_len - len(x) - b_pad), (0, 0)], + mode="constant", constant_values=constant_values) + return x + +# from: https://github.com/keras-team/keras/blob/master/keras/utils/np_utils.py +# to avoid keras dependency + + +def to_categorical(y, num_classes=None, dtype='float32'): + """Converts a class vector (integers) to binary class matrix. + E.g. for use with categorical_crossentropy. + # Arguments + y: class vector to be converted into a matrix + (integers from 0 to num_classes). + num_classes: total number of classes. + dtype: The data type expected by the input, as a string + (`float32`, `float64`, `int32`...) + # Returns + A binary matrix representation of the input. The classes axis + is placed last. + # Example + ```python + # Consider an array of 5 labels out of a set of 3 classes {0, 1, 2}: + > labels + array([0, 2, 1, 2, 0]) + # `to_categorical` converts this into a matrix with as many + # columns as there are classes. The number of rows + # stays the same. 
+ > to_categorical(labels) + array([[ 1., 0., 0.], + [ 0., 0., 1.], + [ 0., 1., 0.], + [ 0., 0., 1.], + [ 1., 0., 0.]], dtype=float32) + ``` + """ + + y = np.array(y, dtype='int') + input_shape = y.shape + if input_shape and input_shape[-1] == 1 and len(input_shape) > 1: + input_shape = tuple(input_shape[:-1]) + y = y.ravel() + if not num_classes: + num_classes = np.max(y) + 1 + n = y.shape[0] + categorical = np.zeros((n, num_classes), dtype=dtype) + categorical[np.arange(n), y] = 1 + output_shape = input_shape + (num_classes,) + categorical = np.reshape(categorical, output_shape) + return categorical + + +# TODO: I know this is too ugly... +class _NPYDataSource(FileDataSource): + def __init__(self, dump_root, col, typ="", speaker_id=None, max_steps=8000, + cin_pad=0, hop_size=256): + self.dump_root = dump_root + self.col = col + self.lengths = [] + self.speaker_id = speaker_id + self.multi_speaker = False + self.speaker_ids = None + self.max_steps = max_steps + self.cin_pad = cin_pad + self.hop_size = hop_size + self.typ = typ + + def collect_files(self): + meta = join(self.dump_root, "train.txt") + if not exists(meta): + paths = sorted(glob(join(self.dump_root, "*-{}.npy".format(self.typ)))) + return paths + + with open(meta, "rb") as f: + lines = f.readlines() + l = lines[0].decode("utf-8").split("|") + assert len(l) == 4 or len(l) == 5 + self.multi_speaker = len(l) == 5 + self.lengths = list( + map(lambda l: int(l.decode("utf-8").split("|")[2]), lines)) + + paths_relative = list(map(lambda l: l.decode("utf-8").split("|")[self.col], lines)) + paths = list(map(lambda f: join(self.dump_root, f), paths_relative)) + + # Exclude small files (assuming lengths are in frame units) + # TODO: consider this for multi-speaker + if self.max_steps is not None: + idx = np.array(self.lengths) * self.hop_size > self.max_steps + 2 * self.cin_pad * self.hop_size + if idx.sum() != len(self.lengths): + print("{} short samples are omitted for training.".format(len(self.lengths) -
idx.sum())) + self.lengths = list(np.array(self.lengths)[idx]) + paths = list(np.array(paths)[idx]) + + if self.multi_speaker: + speaker_ids = list(map(lambda l: int(l.decode("utf-8").split("|")[-1]), lines)) + self.speaker_ids = speaker_ids + if self.speaker_id is not None: + # Filter by speaker_id + # using multi-speaker dataset as a single speaker dataset + indices = np.array(speaker_ids) == self.speaker_id + paths = list(np.array(paths)[indices]) + self.lengths = list(np.array(self.lengths)[indices]) + # aha, need to cast numpy.int64 to int + self.lengths = list(map(int, self.lengths)) + self.multi_speaker = False + + if self.multi_speaker: + speaker_ids_np = list(np.array(self.speaker_ids)[indices]) + self.speaker_ids = list(map(int, speaker_ids_np)) + assert len(paths) == len(self.speaker_ids) + + return paths + + def collect_features(self, path): + return np.load(path) + + +class RawAudioDataSource(_NPYDataSource): + def __init__(self, dump_root, **kwargs): + super(RawAudioDataSource, self).__init__(dump_root, 0, "wave", **kwargs) + + +class MelSpecDataSource(_NPYDataSource): + def __init__(self, dump_root, **kwargs): + super(MelSpecDataSource, self).__init__(dump_root, 1, "feats", **kwargs) + + +class PartialyRandomizedSimilarTimeLengthSampler(Sampler): + """Partially randomized sampler + + 1. Sort by lengths + 2. Pick a small patch and randomize it + 3. 
Permutate mini-batches + """ + + def __init__(self, lengths, batch_size=8, batch_group_size=None): + self.lengths, self.sorted_indices = torch.sort(torch.LongTensor(lengths)) + + self.batch_size = batch_size + if batch_group_size is None: + batch_group_size = min(batch_size * 8, len(self.lengths)) + if batch_group_size % batch_size != 0: + batch_group_size -= batch_group_size % batch_size + + self.batch_group_size = batch_group_size + assert batch_group_size % batch_size == 0 + + def __iter__(self): + indices = self.sorted_indices.numpy() + batch_group_size = self.batch_group_size + s, e = 0, 0 + bins = [] + for i in range(len(indices) // batch_group_size): + s = i * batch_group_size + e = s + batch_group_size + group = indices[s:e] + random.shuffle(group) + bins += [group] + + # Permutate batches + random.shuffle(bins) + binned_idx = np.stack(bins).reshape(-1) + + # Handle last elements + s += batch_group_size + if s < len(indices): + last_bin = indices[len(binned_idx):] + random.shuffle(last_bin) + binned_idx = np.concatenate([binned_idx, last_bin]) + + return iter(torch.tensor(binned_idx).long()) + + def __len__(self): + return len(self.sorted_indices) + + +class PyTorchDataset(object): + def __init__(self, X, Mel): + self.X = X + self.Mel = Mel + # alias + self.multi_speaker = X.file_data_source.multi_speaker + + def __getitem__(self, idx): + if self.Mel is None: + mel = None + else: + mel = self.Mel[idx] + + raw_audio = self.X[idx] + if self.multi_speaker: + speaker_id = self.X.file_data_source.speaker_ids[idx] + else: + speaker_id = None + + # (x,c,g) + return raw_audio, mel, speaker_id + + def __len__(self): + return len(self.X) + + +def sequence_mask(sequence_length, max_len=None): + if max_len is None: + max_len = sequence_length.data.max() + batch_size = sequence_length.size(0) + seq_range = torch.arange(0, max_len).long() + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) + if sequence_length.is_cuda: + seq_range_expand = 
seq_range_expand.cuda() + seq_length_expand = sequence_length.unsqueeze(1) \ + .expand_as(seq_range_expand) + return (seq_range_expand < seq_length_expand).float() + + +# https://discuss.pytorch.org/t/how-to-apply-exponential-moving-average-decay-for-variables/10856/4 +# https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage +class ExponentialMovingAverage(object): + def __init__(self, decay): + self.decay = decay + self.shadow = {} + + def register(self, name, val): + self.shadow[name] = val.clone() + + def update(self, name, x): + assert name in self.shadow + update_delta = self.shadow[name] - x + self.shadow[name] -= (1.0 - self.decay) * update_delta + + +def clone_as_averaged_model(device, model, ema): + assert ema is not None + averaged_model = build_model().to(device) + averaged_model.load_state_dict(model.state_dict()) + for name, param in averaged_model.named_parameters(): + if name in ema.shadow: + param.data = ema.shadow[name].clone() + return averaged_model + + +class MaskedCrossEntropyLoss(nn.Module): + def __init__(self): + super(MaskedCrossEntropyLoss, self).__init__() + self.criterion = nn.CrossEntropyLoss(reduction='none') + + def forward(self, input, target, lengths=None, mask=None, max_len=None): + if lengths is None and mask is None: + raise RuntimeError("Should provide either lengths or mask") + + # (B, T, 1) + if mask is None: + mask = sequence_mask(lengths, max_len).unsqueeze(-1) + + # (B, T, D) + mask_ = mask.expand_as(target) + losses = self.criterion(input, target) + return ((losses * mask_).sum()) / mask_.sum() + + +class DiscretizedMixturelogisticLoss(nn.Module): + def __init__(self): + super(DiscretizedMixturelogisticLoss, self).__init__() + + def forward(self, input, target, lengths=None, mask=None, max_len=None): + if lengths is None and mask is None: + raise RuntimeError("Should provide either lengths or mask") + + # (B, T, 1) + if mask is None: + mask = sequence_mask(lengths, max_len).unsqueeze(-1) + + # (B, T, 
1) + mask_ = mask.expand_as(target) + + losses = discretized_mix_logistic_loss( + input, target, num_classes=hparams.quantize_channels, + log_scale_min=hparams.log_scale_min, reduce=False) + assert losses.size() == target.size() + return ((losses * mask_).sum()) / mask_.sum() + + +class MixtureGaussianLoss(nn.Module): + def __init__(self): + super(MixtureGaussianLoss, self).__init__() + + def forward(self, input, target, lengths=None, mask=None, max_len=None): + if lengths is None and mask is None: + raise RuntimeError("Should provide either lengths or mask") + + # (B, T, 1) + if mask is None: + mask = sequence_mask(lengths, max_len).unsqueeze(-1) + + # (B, T, 1) + mask_ = mask.expand_as(target) + + losses = mix_gaussian_loss( + input, target, log_scale_min=hparams.log_scale_min, reduce=False) + assert losses.size() == target.size() + return ((losses * mask_).sum()) / mask_.sum() + + +def ensure_divisible(length, divisible_by=256, lower=True): + if length % divisible_by == 0: + return length + if lower: + return length - length % divisible_by + else: + return length + (divisible_by - length % divisible_by) + + +def assert_ready_for_upsampling(x, c, cin_pad): + assert len(x) == (len(c) - 2 * cin_pad) * audio.get_hop_size() + + +def collate_fn(batch): + """Create batch + + Args: + batch(tuple): List of tuples + - x[0] (ndarray,int) : list of (T,) + - x[1] (ndarray,int) : list of (T, D) + - x[2] (ndarray,int) : list of (1,), speaker id + Returns: + tuple: Tuple of batch + - x (FloatTensor) : Network inputs (B, C, T) + - y (LongTensor) : Network targets (B, T, 1) + """ + + local_conditioning = len(batch[0]) >= 2 and hparams.cin_channels > 0 + global_conditioning = len(batch[0]) >= 3 and hparams.gin_channels > 0 + + if hparams.max_time_sec is not None: + max_time_steps = int(hparams.max_time_sec * hparams.sample_rate) + elif hparams.max_time_steps is not None: + max_time_steps = hparams.max_time_steps + else: + max_time_steps = None + + # Time resolution adjustment + 
cin_pad = hparams.cin_pad + if local_conditioning: + new_batch = [] + for idx in range(len(batch)): + x, c, g = batch[idx] + if hparams.upsample_conditional_features: + assert_ready_for_upsampling(x, c, cin_pad=0) + if max_time_steps is not None: + max_steps = ensure_divisible(max_time_steps, audio.get_hop_size(), True) + if len(x) > max_steps: + max_time_frames = max_steps // audio.get_hop_size() + s = np.random.randint(cin_pad, len(c) - max_time_frames - cin_pad) + ts = s * audio.get_hop_size() + x = x[ts:ts + audio.get_hop_size() * max_time_frames] + c = c[s - cin_pad:s + max_time_frames + cin_pad, :] + assert_ready_for_upsampling(x, c, cin_pad=cin_pad) + else: + x, c = audio.adjust_time_resolution(x, c) + if max_time_steps is not None and len(x) > max_time_steps: + s = np.random.randint(cin_pad, len(x) - max_time_steps - cin_pad) + x = x[s:s + max_time_steps] + c = c[s - cin_pad:s + max_time_steps + cin_pad, :] + assert len(x) == len(c) + new_batch.append((x, c, g)) + batch = new_batch + else: + new_batch = [] + for idx in range(len(batch)): + x, c, g = batch[idx] + x = audio.trim(x) + if max_time_steps is not None and len(x) > max_time_steps: + s = np.random.randint(0, len(x) - max_time_steps) + if local_conditioning: + x, c = x[s:s + max_time_steps], c[s:s + max_time_steps, :] + else: + x = x[s:s + max_time_steps] + new_batch.append((x, c, g)) + batch = new_batch + + # Lengths + input_lengths = [len(x[0]) for x in batch] + max_input_len = max(input_lengths) + + # (B, T, C) + # pad for time-axis + if is_mulaw_quantize(hparams.input_type): + padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1) + x_batch = np.array([_pad_2d(to_categorical( + x[0], num_classes=hparams.quantize_channels), + max_input_len, 0, padding_value) for x in batch], dtype=np.float32) + else: + x_batch = np.array([_pad_2d(x[0].reshape(-1, 1), max_input_len) + for x in batch], dtype=np.float32) + assert len(x_batch.shape) == 3 + + # (B, T) + if 
is_mulaw_quantize(hparams.input_type):
+        padding_value = P.mulaw_quantize(0, mu=hparams.quantize_channels - 1)
+        y_batch = np.array([_pad(x[0], max_input_len, constant_values=padding_value)
+                            for x in batch], dtype=np.int64)
+    else:
+        y_batch = np.array([_pad(x[0], max_input_len) for x in batch], dtype=np.float32)
+    assert len(y_batch.shape) == 2
+
+    # (B, T, D)
+    if local_conditioning:
+        max_len = max([len(x[1]) for x in batch])
+        c_batch = np.array([_pad_2d(x[1], max_len) for x in batch], dtype=np.float32)
+        assert len(c_batch.shape) == 3
+        # (B x C x T)
+        c_batch = torch.FloatTensor(c_batch).transpose(1, 2).contiguous()
+    else:
+        c_batch = None
+
+    if global_conditioning:
+        g_batch = torch.LongTensor([x[2] for x in batch])
+    else:
+        g_batch = None
+
+    # Convert to channel first, i.e., (B, C, T)
+    x_batch = torch.FloatTensor(x_batch).transpose(1, 2).contiguous()
+    # Add extra axis
+    if is_mulaw_quantize(hparams.input_type):
+        y_batch = torch.LongTensor(y_batch).unsqueeze(-1).contiguous()
+    else:
+        y_batch = torch.FloatTensor(y_batch).unsqueeze(-1).contiguous()
+
+    input_lengths = torch.LongTensor(input_lengths)
+
+    return x_batch, y_batch, c_batch, g_batch, input_lengths
+
+
+def time_string():
+    return datetime.now().strftime('%Y-%m-%d %H:%M')
+
+
+def save_waveplot(path, y_hat, y_target):
+    sr = hparams.sample_rate
+
+    plt.figure(figsize=(16, 6))
+    plt.subplot(2, 1, 1)
+    librosa.display.waveplot(y_target, sr=sr)
+    plt.subplot(2, 1, 2)
+    librosa.display.waveplot(y_hat, sr=sr)
+    plt.tight_layout()
+    plt.savefig(path, format="png")
+    plt.close()
+
+
+def eval_model(global_step, writer, device, model, y, c, g, input_lengths, eval_dir, ema=None):
+    if ema is not None:
+        print("Using averaged model for evaluation")
+        model = clone_as_averaged_model(device, model, ema)
+        model.make_generation_fast_()
+
+    model.eval()
+    idx = np.random.randint(0, len(y))
+    length = input_lengths[idx].data.cpu().item()
+
+    # (T,)
+    y_target = y[idx].view(-1).data.cpu().numpy()[:length]
+
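For reference, the μ-law companding that this file uses via `P.mulaw_quantize` / `P.inv_mulaw_quantize` (where `P` is the audio preprocessing module imported at the top of the file, not shown in this hunk) can be sketched in plain NumPy. This is a hedged re-implementation for illustration only, assuming the standard formulation with `quantize_channels = 256`, i.e. `mu = 255`:

```python
import numpy as np

# Illustrative sketch of mu-law companding (not the project's actual code).
# Assumes quantize_channels = 256, i.e. mu = 255.

def mulaw(x, mu=255):
    # Compress amplitudes in [-1, 1] non-linearly.
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)

def mulaw_quantize(x, mu=255):
    # Map the companded value from [-1, 1] onto integer classes [0, mu].
    return ((mulaw(x, mu) + 1) / 2 * mu + 0.5).astype(np.int64)

def inv_mulaw_quantize(y, mu=255):
    # Approximate inverse: integer class back to an amplitude in [-1, 1].
    x = 2 * y.astype(np.float64) / mu - 1
    return np.sign(x) * ((1 + mu) ** np.abs(x) - 1) / mu
```

Under this formulation, digital silence (amplitude 0.0) maps to class 128, which is why the collate and eval code above pads and initializes with `P.mulaw_quantize(0, ...)`.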
if c is not None: + if hparams.upsample_conditional_features: + c = c[idx, :, :length // audio.get_hop_size() + hparams.cin_pad * 2].unsqueeze(0) + else: + c = c[idx, :, :length].unsqueeze(0) + assert c.dim() == 3 + print("Shape of local conditioning features: {}".format(c.size())) + if g is not None: + # TODO: test + g = g[idx] + print("Shape of global conditioning features: {}".format(g.size())) + + # Dummy silence + if is_mulaw_quantize(hparams.input_type): + initial_value = P.mulaw_quantize(0, hparams.quantize_channels - 1) + elif is_mulaw(hparams.input_type): + initial_value = P.mulaw(0.0, hparams.quantize_channels) + else: + initial_value = 0.0 + + # (C,) + if is_mulaw_quantize(hparams.input_type): + initial_input = to_categorical( + initial_value, num_classes=hparams.quantize_channels).astype(np.float32) + initial_input = torch.from_numpy(initial_input).view( + 1, 1, hparams.quantize_channels) + else: + initial_input = torch.zeros(1, 1, 1).fill_(initial_value) + initial_input = initial_input.to(device) + + # Run the model in fast eval mode + with torch.no_grad(): + y_hat = model.incremental_forward( + initial_input, c=c, g=g, T=length, softmax=True, quantize=True, tqdm=tqdm, + log_scale_min=hparams.log_scale_min) + + if is_mulaw_quantize(hparams.input_type): + y_hat = y_hat.max(1)[1].view(-1).long().cpu().data.numpy() + y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1) + y_target = P.inv_mulaw_quantize(y_target, hparams.quantize_channels - 1) + elif is_mulaw(hparams.input_type): + y_hat = P.inv_mulaw(y_hat.view(-1).cpu().data.numpy(), hparams.quantize_channels) + y_target = P.inv_mulaw(y_target, hparams.quantize_channels) + else: + y_hat = y_hat.view(-1).cpu().data.numpy() + + # Save audio + os.makedirs(eval_dir, exist_ok=True) + path = join(eval_dir, "step{:09d}_predicted.wav".format(global_step)) + librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate) + path = join(eval_dir, "step{:09d}_target.wav".format(global_step)) + 
librosa.output.write_wav(path, y_target, sr=hparams.sample_rate) + + # save figure + path = join(eval_dir, "step{:09d}_waveplots.png".format(global_step)) + save_waveplot(path, y_hat, y_target) + + +def save_states(global_step, writer, y_hat, y, input_lengths, checkpoint_dir=None): + print("Save intermediate states at step {}".format(global_step)) + idx = np.random.randint(0, len(y_hat)) + length = input_lengths[idx].data.cpu().item() + + # (B, C, T) + if y_hat.dim() == 4: + y_hat = y_hat.squeeze(-1) + + if is_mulaw_quantize(hparams.input_type): + # (B, T) + y_hat = F.softmax(y_hat, dim=1).max(1)[1] + + # (T,) + y_hat = y_hat[idx].data.cpu().long().numpy() + y = y[idx].view(-1).data.cpu().long().numpy() + + y_hat = P.inv_mulaw_quantize(y_hat, hparams.quantize_channels - 1) + y = P.inv_mulaw_quantize(y, hparams.quantize_channels - 1) + else: + # (B, T) + if hparams.output_distribution == "Logistic": + y_hat = sample_from_discretized_mix_logistic( + y_hat, log_scale_min=hparams.log_scale_min) + elif hparams.output_distribution == "Normal": + y_hat = sample_from_mix_gaussian( + y_hat, log_scale_min=hparams.log_scale_min) + else: + assert False + + # (T,) + y_hat = y_hat[idx].view(-1).data.cpu().numpy() + y = y[idx].view(-1).data.cpu().numpy() + + if is_mulaw(hparams.input_type): + y_hat = P.inv_mulaw(y_hat, hparams.quantize_channels) + y = P.inv_mulaw(y, hparams.quantize_channels) + + # Mask by length + y_hat[length:] = 0 + y[length:] = 0 + + # Save audio + audio_dir = join(checkpoint_dir, "intermediate", "audio") + os.makedirs(audio_dir, exist_ok=True) + path = join(audio_dir, "step{:09d}_predicted.wav".format(global_step)) + librosa.output.write_wav(path, y_hat, sr=hparams.sample_rate) + path = join(audio_dir, "step{:09d}_target.wav".format(global_step)) + librosa.output.write_wav(path, y, sr=hparams.sample_rate) + +# workaround for https://github.com/pytorch/pytorch/issues/15716 +# the idea is to return outputs and replicas explicitly, so that making pytorch +# not 
release the nodes prematurely (this works around a PyTorch bug)
+
+
+def data_parallel_workaround(model, input):
+    device_ids = list(range(torch.cuda.device_count()))
+    output_device = device_ids[0]
+    replicas = torch.nn.parallel.replicate(model, device_ids)
+    inputs = torch.nn.parallel.scatter(input, device_ids)
+    replicas = replicas[:len(inputs)]
+    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
+    y_hat = torch.nn.parallel.gather(outputs, output_device)
+    return y_hat, outputs, replicas
+
+
+def __train_step(device, phase, epoch, global_step, global_test_step,
+                 model, optimizer, writer, criterion,
+                 x, y, c, g, input_lengths,
+                 checkpoint_dir, eval_dir=None, do_eval=False, ema=None):
+    sanity_check(model, c, g)
+
+    # x : (B, C, T)
+    # y : (B, T, 1)
+    # c : (B, C, T)
+    # g : (B,)
+    train = (phase == "train_no_dev")
+    clip_thresh = hparams.clip_thresh
+    if train:
+        model.train()
+        step = global_step
+    else:
+        model.eval()
+        step = global_test_step
+
+    # Learning rate schedule
+    current_lr = hparams.optimizer_params["lr"]
+    if train and hparams.lr_schedule is not None:
+        lr_schedule_f = getattr(lrschedule, hparams.lr_schedule)
+        current_lr = lr_schedule_f(
+            hparams.optimizer_params["lr"], step, **hparams.lr_schedule_kwargs)
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = current_lr
+    optimizer.zero_grad()
+
+    # Prepare data
+    x, y = x.to(device), y.to(device)
+    input_lengths = input_lengths.to(device)
+    c = c.to(device) if c is not None else None
+    g = g.to(device) if g is not None else None
+
+    # (B, T, 1)
+    mask = sequence_mask(input_lengths, max_len=x.size(-1)).unsqueeze(-1)
+    mask = mask[:, 1:, :]
+
+    # Apply the model (regular forward pass)
+    # NOTE: softmax is handled in F.cross_entropy
+    # y_hat: (B x C x T)
+
+    if use_cuda:
+        # multi-GPU support; batch size must be divisible by the number of GPUs
+        y_hat, _outputs, _replicas = data_parallel_workaround(model, (x, c, g, False))
+    else:
+        y_hat = model(x, c, g, False)
+
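The slicing in the loss call that follows implements teacher-forced, one-step-ahead prediction: the prediction at time t is scored against the input sample at t+1, so the last prediction and the first target are dropped. A minimal shape sketch (toy sizes, NumPy in place of torch tensors):

```python
import numpy as np

B, C, T = 2, 256, 8                        # toy batch / class / time sizes
y_hat = np.random.randn(B, C, T)           # per-timestep class scores, (B, C, T)
y = np.random.randint(0, C, (B, T, 1))     # quantized target samples, (B, T, 1)

pred = y_hat[:, :, :-1]    # predictions for steps 0 .. T-2
target = y[:, 1:, :]       # ground truth for steps 1 .. T-1
assert pred.shape[2] == target.shape[1] == T - 1
```

This also explains the `mask = mask[:, 1:, :]` line above: the mask must align with the shifted targets, not the raw inputs.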
if is_mulaw_quantize(hparams.input_type):
+        # we need 4d inputs for spatial cross entropy loss
+        # (B, C, T, 1)
+        y_hat = y_hat.unsqueeze(-1)
+        loss = criterion(y_hat[:, :, :-1, :], y[:, 1:, :], mask=mask)
+    else:
+        loss = criterion(y_hat[:, :, :-1], y[:, 1:, :], mask=mask)
+
+    if train and step > 0 and step % hparams.checkpoint_interval == 0:
+        save_states(step, writer, y_hat, y, input_lengths, checkpoint_dir)
+        save_checkpoint(device, model, optimizer, step, checkpoint_dir, epoch, ema)
+
+    if do_eval:
+        # NOTE: use train step (i.e., global_step) for filename
+        eval_model(global_step, writer, device, model, y, c, g, input_lengths, eval_dir, ema)
+
+    # Update
+    if train:
+        loss.backward()
+        if clip_thresh > 0:
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), clip_thresh)
+        optimizer.step()
+        # update moving average
+        if ema is not None:
+            for name, param in model.named_parameters():
+                if name in ema.shadow:
+                    ema.update(name, param.data)
+
+    # Logs
+    writer.add_scalar("{} loss".format(phase), float(loss.item()), step)
+    if train:
+        if clip_thresh > 0:
+            writer.add_scalar("gradient norm", grad_norm, step)
+        writer.add_scalar("learning rate", current_lr, step)
+
+    return loss.item()
+
+
+def train_loop(device, model, data_loaders, optimizer, writer, checkpoint_dir=None):
+    if is_mulaw_quantize(hparams.input_type):
+        criterion = MaskedCrossEntropyLoss()
+    else:
+        if hparams.output_distribution == "Logistic":
+            criterion = DiscretizedMixturelogisticLoss()
+        elif hparams.output_distribution == "Normal":
+            criterion = MixtureGaussianLoss()
+        else:
+            raise RuntimeError(
+                "Unsupported output distribution type: {}".format(
+                    hparams.output_distribution))
+
+    if hparams.exponential_moving_average:
+        ema = ExponentialMovingAverage(hparams.ema_decay)
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                ema.register(name, param.data)
+    else:
+        ema = None
+
+    global global_step, global_epoch, global_test_step
+    while global_epoch <
hparams.nepochs: + for phase, data_loader in data_loaders.items(): + train = (phase == "train_no_dev") + running_loss = 0. + test_evaluated = False + for step, (x, y, c, g, input_lengths) in tqdm(enumerate(data_loader)): + # Whether to save eval (i.e., online decoding) result + do_eval = False + eval_dir = join(checkpoint_dir, "intermediate", "{}_eval".format(phase)) + # Do eval per eval_interval for train + if train and global_step > 0 \ + and global_step % hparams.train_eval_interval == 0: + do_eval = True + # Do eval for test + # NOTE: Decoding WaveNet is quite time consuming, so + # do only once in a single epoch for testset + if not train and not test_evaluated \ + and global_epoch % hparams.test_eval_epoch_interval == 0: + do_eval = True + test_evaluated = True + if do_eval: + print("[{}] Eval at train step {}".format(phase, global_step)) + + # Do step + running_loss += __train_step(device, + phase, global_epoch, global_step, global_test_step, model, + optimizer, writer, criterion, x, y, c, g, input_lengths, + checkpoint_dir, eval_dir, do_eval, ema) + + # update global state + if train: + global_step += 1 + else: + global_test_step += 1 + + if global_step >= hparams.max_train_steps: + print("Training reached max train steps ({}). 
Will exit.".format(hparams.max_train_steps))
+                    return ema
+
+            # log per epoch
+            averaged_loss = running_loss / len(data_loader)
+            writer.add_scalar("{} loss (per epoch)".format(phase),
+                              averaged_loss, global_epoch)
+            print("Step {} [{}] Loss: {}".format(
+                global_step, phase, running_loss / len(data_loader)))
+
+        global_epoch += 1
+    return ema
+
+
+def save_checkpoint(device, model, optimizer, step, checkpoint_dir, epoch, ema=None):
+    checkpoint_path = join(
+        checkpoint_dir, "checkpoint_step{:09d}.pth".format(global_step))
+    optimizer_state = optimizer.state_dict() if hparams.save_optimizer_state else None
+    global global_test_step
+    torch.save({
+        "state_dict": model.state_dict(),
+        "optimizer": optimizer_state,
+        "global_step": step,
+        "global_epoch": epoch,
+        "global_test_step": global_test_step,
+    }, checkpoint_path)
+    print("Saved checkpoint:", checkpoint_path)
+
+    import shutil
+    latest_pth = join(checkpoint_dir, "checkpoint_latest.pth")
+    shutil.copyfile(checkpoint_path, latest_pth)
+
+    if ema is not None:
+        averaged_model = clone_as_averaged_model(device, model, ema)
+        checkpoint_path = join(
+            checkpoint_dir, "checkpoint_step{:09d}_ema.pth".format(global_step))
+        torch.save({
+            "state_dict": averaged_model.state_dict(),
+            "optimizer": optimizer_state,
+            "global_step": step,
+            "global_epoch": epoch,
+            "global_test_step": global_test_step,
+        }, checkpoint_path)
+        print("Saved averaged checkpoint:", checkpoint_path)
+
+        latest_pth = join(checkpoint_dir, "checkpoint_latest_ema.pth")
+        shutil.copyfile(checkpoint_path, latest_pth)
+
+
+def build_model():
+    if is_mulaw_quantize(hparams.input_type):
+        if hparams.out_channels != hparams.quantize_channels:
+            raise RuntimeError(
+                "out_channels must equal quantize_channels if input_type is 'mulaw-quantize'")
+    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
+        s = "Upsample conv layers were specified while local conditioning disabled. 
" + s += "Notice that upsample conv layers will never be used." + warn(s) + + upsample_params = hparams.upsample_params + upsample_params["cin_channels"] = hparams.cin_channels + upsample_params["cin_pad"] = hparams.cin_pad + model = WaveNet( + out_channels=hparams.out_channels, + layers=hparams.layers, + stacks=hparams.stacks, + residual_channels=hparams.residual_channels, + gate_channels=hparams.gate_channels, + skip_out_channels=hparams.skip_out_channels, + cin_channels=hparams.cin_channels, + gin_channels=hparams.gin_channels, + n_speakers=hparams.n_speakers, + dropout=hparams.dropout, + kernel_size=hparams.kernel_size, + cin_pad=hparams.cin_pad, + upsample_conditional_features=hparams.upsample_conditional_features, + upsample_params=upsample_params, + scalar_input=is_scalar_input(hparams.input_type), + output_distribution=hparams.output_distribution, + ) + return model + + +def _load(checkpoint_path): + if use_cuda: + checkpoint = torch.load(checkpoint_path) + else: + checkpoint = torch.load(checkpoint_path, + map_location=lambda storage, loc: storage) + return checkpoint + + +def load_checkpoint(path, model, optimizer, reset_optimizer): + global global_step + global global_epoch + global global_test_step + + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + model.load_state_dict(checkpoint["state_dict"]) + if not reset_optimizer: + optimizer_state = checkpoint["optimizer"] + if optimizer_state is not None: + print("Load optimizer state from {}".format(path)) + optimizer.load_state_dict(checkpoint["optimizer"]) + global_step = checkpoint["global_step"] + global_epoch = checkpoint["global_epoch"] + global_test_step = checkpoint.get("global_test_step", 0) + + return model + + +# https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3 +def restore_parts(path, model): + print("Restore part of the model from: {}".format(path)) + state = _load(path)["state_dict"] + model_dict = model.state_dict() + valid_state_dict = {k: 
v for k, v in state.items() if k in model_dict} + + try: + model_dict.update(valid_state_dict) + model.load_state_dict(model_dict) + except RuntimeError as e: + # there should be invalid size of weight(s), so load them per parameter + print(str(e)) + model_dict = model.state_dict() + for k, v in valid_state_dict.items(): + model_dict[k] = v + try: + model.load_state_dict(model_dict) + except RuntimeError as e: + print(str(e)) + warn("{}: may contain invalid size of weight. skipping...".format(k)) + + +def get_data_loaders(dump_root, speaker_id, test_shuffle=True): + data_loaders = {} + local_conditioning = hparams.cin_channels > 0 + + if hparams.max_time_steps is not None: + max_steps = ensure_divisible(hparams.max_time_steps, audio.get_hop_size(), True) + else: + max_steps = None + + for phase in ["train_no_dev", "dev"]: + train = phase == "train_no_dev" + X = FileSourceDataset( + RawAudioDataSource(join(dump_root, phase), speaker_id=speaker_id, + max_steps=max_steps, cin_pad=hparams.cin_pad, + hop_size=audio.get_hop_size())) + if local_conditioning: + Mel = FileSourceDataset( + MelSpecDataSource(join(dump_root, phase), speaker_id=speaker_id, + max_steps=max_steps, cin_pad=hparams.cin_pad, + hop_size=audio.get_hop_size())) + assert len(X) == len(Mel) + print("Local conditioning enabled. 
Shape of a sample: {}.".format( + Mel[0].shape)) + else: + Mel = None + print("[{}]: length of the dataset is {}".format(phase, len(X))) + + if train: + lengths = np.array(X.file_data_source.lengths) + # Prepare sampler + sampler = PartialyRandomizedSimilarTimeLengthSampler( + lengths, batch_size=hparams.batch_size) + shuffle = False + # make sure that there's no sorting bugs for https://github.com/r9y9/wavenet_vocoder/issues/130 + sampler_idx = np.asarray(sorted(list(map(lambda s: int(s), sampler)))) + assert (sampler_idx == np.arange(len(sampler_idx), dtype=np.int)).all() + else: + sampler = None + shuffle = test_shuffle + + dataset = PyTorchDataset(X, Mel) + data_loader = data_utils.DataLoader( + dataset, batch_size=hparams.batch_size, drop_last=True, + num_workers=hparams.num_workers, sampler=sampler, shuffle=shuffle, + collate_fn=collate_fn, pin_memory=hparams.pin_memory) + + speaker_ids = {} + if X.file_data_source.multi_speaker: + for idx, (x, c, g) in enumerate(dataset): + if g is not None: + try: + speaker_ids[g] += 1 + except KeyError: + speaker_ids[g] = 1 + if len(speaker_ids) > 0: + print("Speaker stats:", speaker_ids) + + data_loaders[phase] = data_loader + + return data_loaders + + +if __name__ == "__main__": + args = docopt(__doc__) + print("Command line args:\n", args) + checkpoint_dir = args["--checkpoint-dir"] + checkpoint_path = args["--checkpoint"] + checkpoint_restore_parts = args["--restore-parts"] + speaker_id = args["--speaker-id"] + speaker_id = int(speaker_id) if speaker_id is not None else None + preset = args["--preset"] + + dump_root = args["--dump-root"] + if dump_root is None: + dump_root = join(dirname(__file__), "data", "ljspeech") + + log_event_path = args["--log-event-path"] + reset_optimizer = args["--reset-optimizer"] + + # Load preset if specified + if preset is not None: + with open(preset) as f: + hparams.parse_json(f.read()) + # Override hyper parameters + hparams.parse(args["--hparams"]) + assert hparams.name == 
"wavenet_vocoder" + print(hparams_debug_string()) + + fs = hparams.sample_rate + + os.makedirs(checkpoint_dir, exist_ok=True) + + output_json_path = join(checkpoint_dir, "hparams.json") + with open(output_json_path, "w") as f: + json.dump(hparams.values(), f, indent=2) + + # Dataloader setup + data_loaders = get_data_loaders(dump_root, speaker_id, test_shuffle=True) + + maybe_set_epochs_based_on_max_steps(hparams, len(data_loaders["train_no_dev"])) + + device = torch.device("cuda" if use_cuda else "cpu") + + # Model + model = build_model().to(device) + + receptive_field = model.receptive_field + print("Receptive field (samples / ms): {} / {}".format( + receptive_field, receptive_field / fs * 1000)) + + from torch import optim + Optimizer = getattr(optim, hparams.optimizer) + optimizer = Optimizer(model.parameters(), **hparams.optimizer_params) + + if checkpoint_restore_parts is not None: + restore_parts(checkpoint_restore_parts, model) + + # Load checkpoints + if checkpoint_path is not None: + load_checkpoint(checkpoint_path, model, optimizer, reset_optimizer) + + # Setup summary writer for tensorboard + if log_event_path is None: + log_event_path = "log/run-test" + str(datetime.now()).replace(" ", "_") + print("TensorBoard event log path: {}".format(log_event_path)) + writer = SummaryWriter(log_dir=log_event_path) + + # Train! 
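`train_loop` above maintains an exponential moving average of the weights via `ExponentialMovingAverage`; its in-place update `shadow -= (1 - decay) * (shadow - x)` is algebraically identical to the familiar form `shadow = decay * shadow + (1 - decay) * x`. A scalar sanity check of that equivalence:

```python
decay = 0.9
shadow, x = 1.0, 0.0

# The in-place form used by ExponentialMovingAverage.update:
shadow -= (1 - decay) * (shadow - x)

# Equivalent to the textbook form decay * shadow_old + (1 - decay) * x:
assert abs(shadow - (decay * 1.0 + (1 - decay) * 0.0)) < 1e-12
```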
+ ema = None + try: + ema = train_loop(device, model, data_loaders, optimizer, writer, + checkpoint_dir=checkpoint_dir) + except KeyboardInterrupt: + print("Interrupted!") + pass + finally: + save_checkpoint( + device, model, optimizer, global_step, checkpoint_dir, global_epoch, ema) + + print("Finished") + + sys.exit(0) diff --git a/HuaWeiExperiment/wavenet/utils/parse_options.sh b/HuaWeiExperiment/wavenet/utils/parse_options.sh new file mode 100644 index 00000000..335e69e9 --- /dev/null +++ b/HuaWeiExperiment/wavenet/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! 
-r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### Now we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. 
+[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/HuaWeiExperiment/wavenet/wavenet_vocoder/__init__.py b/HuaWeiExperiment/wavenet/wavenet_vocoder/__init__.py new file mode 100644 index 00000000..e34d0b96 --- /dev/null +++ b/HuaWeiExperiment/wavenet/wavenet_vocoder/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""init""" +from __future__ import with_statement, print_function, absolute_import +from .wavenet import WaveNet diff --git a/HuaWeiExperiment/wavenet/wavenet_vocoder/conv.py b/HuaWeiExperiment/wavenet/wavenet_vocoder/conv.py new file mode 100644 index 00000000..85feed02 --- /dev/null +++ b/HuaWeiExperiment/wavenet/wavenet_vocoder/conv.py @@ -0,0 +1,182 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Extended Conv1D.""" + +import math +import numpy as np +from mindspore import nn, Tensor +from mindspore.ops import operations as P +import mindspore.common.dtype as mstype +from mindspore import context + +class Conv1d(nn.Conv1d): + """ + Extended nn.Conv1d to adapt to incremental dilated convolutions. + During training, the standard Conv1d forward is used; during evaluation, incremental_forward is called. + To improve inference speed, tensors are converted to NumPy arrays and the subsequent computation runs in NumPy. + These operations will be replaced with MindSpore ops in the future. Currently, some operations are not supported by + MindSpore, and a mixed use of NumPy and MindSpore ops is slow. + + """ + + def __init__(self, *args, **kwargs): + super(Conv1d, self).__init__(*args, **kwargs) + self.clear_buffer() + self._linearized_weight = None + self.transpose_op = P.Transpose() + self.reshape_op = P.Reshape() + self.squeeze_op = P.Squeeze(-2) + self.zeros = P.Zeros() + self.concat_op = P.Concat(axis=1) + self.matmul = P.MatMul(transpose_b=True) + self.bias_add = P.BiasAdd() + self.get_weight = None + self.get_bias = None + + def incremental_forward(self, inputs, is_numpy=True): + if is_numpy: + return self.incremental_forward_numpy(inputs) + return self.incremental_forward_pynative(inputs) + + def incremental_forward_pynative(self, inputs): + """ + Incremental forward.
+ + Args: + inputs: B x T x C + + Returns: + ndarray + + """ + # input: (B, T, C) + if self.training: + raise RuntimeError('incremental_forward only supports eval mode') + + if self.get_weight is None: + self.get_weight = self._get_linearized_weight() + + if self.get_bias is None and self.bias is not None: + self.get_bias = self.bias + + # Note mindspore uses Conv2D to construct Conv1D + kw = self.kernel_size[1] + dilation = self.dilation[1] + + bsz = inputs.shape[0] # input: bsz x len x dim + if kw > 1: + if self.input_buffer is None: + init_buffer = self.zeros((bsz, kw + (kw - 1) * (dilation - 1), inputs.shape[2]), mstype.float32) + self.input_buffer = self.concat_op((init_buffer[:, 1:, :], inputs[:, 0:1, :])) + else: + # shift buffer + self.input_buffer = self.concat_op((self.input_buffer[:, 1:, :], inputs[:, 0:1, :])) + inputs = self.input_buffer + if dilation > 1: + if context.get_context("device_target") == "CPU": + inputs = self.transpose_op(inputs, (1, 0, 2)) + inputs = inputs[0::dilation, :, :] + inputs = self.transpose_op(inputs, (1, 0, 2)) + else: + inputs = inputs[:, 0::dilation, :] + + output = self.matmul(self.reshape_op(inputs, (bsz, -1)), self.get_weight) + if self.bias is not None: + output = self.bias_add(output, self.bias) + return self.reshape_op(output, (bsz, 1, -1)) + + def incremental_forward_numpy(self, inputs): + """ + Incremental forward. 
+ + Args: + inputs: B x T x C + + Returns: + ndarray + + """ + # input: (B, T, C) + if self.training: + raise RuntimeError('incremental_forward only supports eval mode') + + if self.get_weight is None: + weight = self._get_linearized_weight() + self.get_weight = weight.asnumpy() + + if self.get_bias is None and self.bias is not None: + bias = self.bias + self.get_bias = bias.asnumpy() + + # Note mindspore uses Conv2D to construct Conv1D + kw = self.kernel_size[1] + dilation = self.dilation[1] + + bsz = inputs.shape[0] # input: bsz x len x dim + if kw > 1: + if self.input_buffer is None: + self.input_buffer = np.zeros((bsz, kw + (kw - 1) * (dilation - 1), inputs.shape[2]), dtype=np.float32) + else: + # shift buffer + self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :] + # append next + self.input_buffer[:, -1, :] = inputs[:, -1, :] + inputs = self.input_buffer + if dilation > 1: + inputs = inputs[:, 0::dilation, :] + output = inputs.reshape(bsz, -1).dot(self.get_weight.T) + if self.bias is not None: + output = output + np.expand_dims(self.get_bias, 0) + return np.reshape(output, (bsz, 1, -1)) + + def clear_buffer(self): + self.input_buffer = None + + def _get_linearized_weight(self): + """ + get linearized weight + """ + weight = self.squeeze_op(self.weight) + if self._linearized_weight is None: + # Note mindspore uses Conv2D to construct Conv1D + kw = self.kernel_size[1] + if weight.shape == (self.out_channels, self.in_channels, kw): + weight = self.transpose_op(weight, (0, 2, 1)) + else: + weight = self.transpose_op(weight, (2, 0, 1)) + self._linearized_weight = self.reshape_op(weight, (self.out_channels, -1)) + return self._linearized_weight + + def _clear_linearized_weight(self, *args): + self._linearized_weight = None + + def _initialize_weights(self): + """ + weight initialization + """ + self.init_parameters_data() + std_mul = 4.0 + for _, m in self.cells_and_names(): + if isinstance(m, nn.Conv1d): + std = math.sqrt((std_mul * 0.1) / (m.kernel_size[1] 
* self.in_channels)) + m.weight.set_data(Tensor(np.random.normal(0, std, m.weight.data.shape).astype("float32"))) + if m.bias is not None: + m.bias.set_data( + Tensor(np.zeros(m.bias.data.shape, dtype="float32"))) + elif isinstance(m, nn.BatchNorm2d): + m.gamma.set_data( + Tensor(np.ones(m.gamma.data.shape, dtype="float32"))) + m.beta.set_data( + Tensor(np.zeros(m.beta.data.shape, dtype="float32"))) diff --git a/HuaWeiExperiment/wavenet/wavenet_vocoder/mixture.py b/HuaWeiExperiment/wavenet/wavenet_vocoder/mixture.py new file mode 100644 index 00000000..a594fdd4 --- /dev/null +++ b/HuaWeiExperiment/wavenet/wavenet_vocoder/mixture.py @@ -0,0 +1,386 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Loss function for training and sample function for testing. 
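The losses and samplers below lean on a numerically stable log-sum-exp (the `log_sum_exp` Cell defined in this module). As an illustration of the same max-subtraction trick in plain NumPy (hypothetical helper name, not part of this module):

```python
import numpy as np

def log_sum_exp_np(x, axis=-1):
    # Subtract the per-axis max before exponentiating so exp() cannot
    # overflow, then add it back outside the log.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))
```

For inputs around 1000, a naive `log(sum(exp(x)))` overflows to inf, while this version returns the exact value.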
+""" +import numpy as np +import mindspore as ms +from mindspore import Tensor +import mindspore.nn as nn +import mindspore.ops as P +from mindspore import context + + +class log_sum_exp(nn.Cell): + """Numerically stable log_sum_exp + """ + + def __init__(self): + super(log_sum_exp, self).__init__() + self.maxi = P.ReduceMax() + self.maxi_dim = P.ReduceMax(keep_dims=True) + self.log = P.Log() + self.sums = P.ReduceSum() + self.exp = P.Exp() + + def construct(self, x): + axis = len(x.shape) - 1 + m = self.maxi(x, axis) + m2 = self.maxi_dim(x, axis) + return m + self.log(self.sums(self.exp(x - m2), axis)) + + +class log_softmax(nn.Cell): + """ + replacement of P.LogSoftmax(-1) in CPU mode + only support x.shape == 2 or 3 + """ + + def __init__(self): + super(log_softmax, self).__init__() + self.maxi = P.ReduceMax() + self.log = P.Log() + self.sums = P.ReduceSum() + self.exp = P.Exp() + self.axis = -1 + self.concat = P.Concat(-1) + self.expanddims = P.ExpandDims() + + def construct(self, x): + """ + + Args: + x (Tensor): input + + Returns: + Tensor: log_softmax of input + + """ + c = self.maxi(x, self.axis) + logs, lsm = None, None + if len(x.shape) == 2: + for j in range(x.shape[-1]): + temp = self.expanddims(self.exp(x[:, j] - c), -1) + logs = temp if j == 0 else self.concat((logs, temp)) + sums = self.sums(logs, -1) + for i in range(x.shape[-1]): + temp = self.expanddims(x[:, i] - c - self.log(sums), -1) + lsm = temp if i == 0 else self.concat((lsm, temp)) + return lsm + if len(x.shape) == 3: + for j in range(x.shape[-1]): + temp = self.expanddims(self.exp(x[:, :, j] - c), -1) + logs = temp if j == 0 else self.concat((logs, temp)) + sums = self.sums(logs, -1) + for i in range(x.shape[-1]): + temp = self.expanddims(x[:, :, i] - c - self.log(sums), -1) + lsm = temp if i == 0 else self.concat((lsm, temp)) + return lsm + return None + + +class Stable_softplus(nn.Cell): + """Numerically stable softplus + """ + + def __init__(self): + super(Stable_softplus, 
self).__init__() + self.log_op = P.Log() + self.abs_op = P.Abs() + self.relu_op = P.ReLU() + self.exp_op = P.Exp() + + def construct(self, x): + return self.log_op(1 + self.exp_op(- self.abs_op(x))) + self.relu_op(x) + + +class discretized_mix_logistic_loss(nn.Cell): + """ + Discretized_mix_logistic_loss + + Args: + num_classes (int): Num_classes + log_scale_min (float): Log scale minimum value + + """ + + def __init__(self, num_classes=256, log_scale_min=-7.0, reduce=True): + super(discretized_mix_logistic_loss, self).__init__() + self.num_classes = num_classes + self.log_scale_min = log_scale_min + self.reduce = reduce + self.transpose_op = P.Transpose() + self.exp = P.Exp() + self.sigmoid = P.Sigmoid() + self.softplus = Stable_softplus() + self.log = P.Log() + self.cast = P.Cast() + self.expand_dims = P.ExpandDims() + self.tile = P.Tile() + self.maximum = P.Maximum() + self.sums = P.ReduceSum() + self.lse = log_sum_exp() + self.reshape = P.Reshape() + self.factor = self.log(Tensor((self.num_classes - 1) / 2, ms.float32)) + self.tensor_one = Tensor(1., ms.float32) + + if context.get_context("device_target") == "CPU": + self.logsoftmax = log_softmax() + else: + self.logsoftmax = P.LogSoftmax(-1) + + def construct(self, y_hat, y): + """ + + Args: + y_hat (Tensor): Predicted distribution + y (Tensor): Target + + Returns: + Tensor: Discretized_mix_logistic_loss + + """ + nr_mix = y_hat.shape[1] // 3 + + # (B x T x C) + y_hat = self.transpose_op(y_hat, (0, 2, 1)) + + # (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix:2 * nr_mix] + min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix)) + log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut) + + # B x T x 1 -> B x T x num_mixtures + y = self.tile(y, (1, 1, nr_mix)) + + centered_y = y - means + inv_stdv = self.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1. 
/ (self.num_classes - 1)) + cdf_plus = self.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1. / (self.num_classes - 1)) + cdf_min = self.sigmoid(min_in) + + log_cdf_plus = plus_in - self.softplus(plus_in) + + log_one_minus_cdf_min = -self.softplus(min_in) + + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + log_pdf_mid = mid_in - log_scales - 2. * self.softplus(mid_in) + + inner_inner_cond = self.cast(cdf_delta > 1e-5, ms.float32) + min_cut2 = 1e-12 * self.tile(self.tensor_one, cdf_delta.shape) + inner_inner_out = inner_inner_cond * \ + self.log(self.maximum(cdf_delta, min_cut2)) + \ + (1. - inner_inner_cond) * (log_pdf_mid - self.factor) + inner_cond = self.cast(y > 0.999, ms.float32) + inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out + cond = self.cast(y < -0.999, ms.float32) + log_probs = cond * log_cdf_plus + (1. - cond) * inner_out + + a, b, c = logit_probs.shape[0], logit_probs.shape[1], logit_probs.shape[2] + logit_probs = self.logsoftmax(self.reshape(logit_probs, (-1, c))) + logit_probs = self.reshape(logit_probs, (a, b, c)) + + log_probs = log_probs + logit_probs + if self.reduce: + return -self.sums(self.lse(log_probs)) + return self.expand_dims(-self.lse(log_probs), -1) + + +def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0): + """ + Sample from discretized mixture of logistic distributions + + Args: + y (ndarray): B x C x T + log_scale_min (float): Log scale minimum value + + Returns: + ndarray + """ + nr_mix = y.shape[1] // 3 + + # B x T x C + y = np.transpose(y, (0, 2, 1)) + logit_probs = y[:, :, :nr_mix] + + temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape) + temp = logit_probs - np.log(- np.log(temp)) + + argmax = np.argmax(temp, axis=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = np.eye(nr_mix)[argmax] + means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1) + log_scales = np.clip(np.sum( + y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1), 
a_min=log_scale_min, a_max=None) + + u = np.random.uniform(1e-5, 1.0 - 1e-5, means.shape) + x = means + np.exp(log_scales) * (np.log(u) - np.log(1. - u)) + x = np.clip(x, -1., 1.) + return x.astype(np.float32) + + +class mix_gaussian_loss(nn.Cell): + """ + Mix gaussian loss + """ + + def __init__(self, log_scale_min=-7.0, reduce=True): + super(mix_gaussian_loss, self).__init__() + self.log_scale_min = log_scale_min + self.reduce = reduce + self.transpose_op = P.Transpose() + self.maximum = P.Maximum() + self.tile = P.Tile() + self.exp = P.Exp() + self.expand_dims = P.ExpandDims() + self.sums = P.ReduceSum() + self.lse = log_sum_exp() + self.sq = P.Square() + self.sqrt = P.Sqrt() + self.const = P.ScalarToTensor() + self.log = P.Log() + self.tensor_one = Tensor(1., ms.float32) + + if context.get_context("device_target") == "CPU": + self.logsoftmax = log_softmax() + else: + self.logsoftmax = P.LogSoftmax(-1) + + def construct(self, y_hat, y): + """ + + Args: + y_hat (Tensor): Predicted probability + y (Tensor): Target + + Returns: + Tensor: Mix_gaussian_loss + + """ + C = y_hat.shape[1] + if C == 2: + nr_mix = 1 + else: + nr_mix = y_hat.shape[1] // 3 + + # (B x T x C) + y_hat = self.transpose_op(y_hat, (0, 2, 1)) + + if C == 2: + logit_probs = None + means = y_hat[:, :, 0:1] + min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], 1)) + log_scales = self.maximum(y_hat[:, :, 1:2], min_cut) + else: + # (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix:2 * nr_mix] + min_cut = self.log_scale_min * self.tile(self.tensor_one, (y_hat.shape[0], y_hat.shape[1], nr_mix)) + log_scales = self.maximum(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min_cut) + + # B x T x 1 -> B x T x num_mixtures + y = self.tile(y, (1, 1, nr_mix)) + centered_y = y - means + + sd = self.exp(log_scales) + unnormalized_log_prob = -1. * (self.sq(centered_y - 0.)) / (2. * self.sq(sd)) + neg_normalization = -1. * self.log(self.const(2. 
* np.pi)) / 2. - self.log(sd) + log_probs = unnormalized_log_prob + neg_normalization + + if nr_mix > 1: + log_probs = log_probs + self.logsoftmax(logit_probs) + + if self.reduce: + if nr_mix == 1: + return -self.sums(log_probs) + return -self.sums(self.lse(log_probs)) + if nr_mix == 1: + return -log_probs + return self.expand_dims(-self.lse(log_probs), -1) + + +def sample_from_mix_gaussian(y, log_scale_min=-7.0): + """ + Sample_from_mix_gaussian + + Args: + y (ndarray): B x C x T + + Returns: + ndarray + + """ + C = y.shape[1] + if C == 2: + nr_mix = 1 + else: + nr_mix = y.shape[1] // 3 + + # B x T x C + y = np.transpose(y, (0, 2, 1)) + + if C == 2: + logit_probs = None + else: + logit_probs = y[:, :, :nr_mix] + + if nr_mix > 1: + temp = np.random.uniform(1e-5, 1.0 - 1e-5, logit_probs.shape) + temp = logit_probs - np.log(- np.log(temp)) + argmax = np.argmax(temp, axis=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = np.eye(nr_mix)[argmax] + + means = np.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, axis=-1) + + log_scales = np.sum(y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, axis=-1) + else: + if C == 2: + means, log_scales = y[:, :, 0], y[:, :, 1] + elif C == 3: + means, log_scales = y[:, :, 1], y[:, :, 2] + else: + assert False, "shouldn't happen" + + scales = np.exp(log_scales) + x = np.random.normal(loc=means, scale=scales) + x = np.clip(x, -1., 1.) 
+ return x.astype(np.float32) + + +# self-implemented onehotcategorical distribution +# https://zhuanlan.zhihu.com/p/59550457 +def sample_from_mix_onehotcategorical(x): + """ + Sample_from_mix_onehotcategorical + + Args: + x (ndarray): Predicted softmax probability + + Returns: + ndarray + + """ + pi = np.log(x) + u = np.random.uniform(0, 1, x.shape) + g = -np.log(-np.log(u)) + c = np.argmax(pi + g, axis=1) + return np.array(np.eye(256)[c], dtype=np.float32) diff --git a/HuaWeiExperiment/wavenet/wavenet_vocoder/modules.py b/HuaWeiExperiment/wavenet/wavenet_vocoder/modules.py new file mode 100644 index 00000000..208049c7 --- /dev/null +++ b/HuaWeiExperiment/wavenet/wavenet_vocoder/modules.py @@ -0,0 +1,213 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Modules for WaveNet. 
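The residual block defined below combines the two halves of its dilated-conv output with WaveNet's gated activation unit, tanh(a) * sigmoid(b). A minimal NumPy sketch of that gating step (hypothetical helper, channel-first layout assumed):

```python
import numpy as np

def gated_activation(x):
    # x: (batch, 2*C, time) -- the dilated conv emits both halves at once.
    # tanh(a) is the "filter" half, sigmoid(b) is the "gate" half.
    a, b = np.split(x, 2, axis=1)
    return np.tanh(a) * (1.0 / (1.0 + np.exp(-b)))
```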
+""" +from __future__ import with_statement, print_function, absolute_import +import math +import numpy as np +from wavenet_vocoder import conv +from mindspore import nn +from mindspore.ops import operations as P + + +def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): + m = conv.Conv1d(in_channels, out_channels, kernel_size, **kwargs) + return m + + +def Conv1d1x1(in_channels, out_channels, has_bias=True): + return Conv1d(in_channels, out_channels, kernel_size=1, pad_mode='pad', padding=0, dilation=1, has_bias=has_bias) + + +def Embedding(num_embeddings, embedding_dim, padding_idx, std=0.01): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + return m + + +def _conv1x1_forward(conv_, x, is_incremental, is_numpy=True): + """ + Conv1x1 forward + """ + if is_incremental: + x = conv_.incremental_forward(x, is_numpy=is_numpy) + else: + x = conv_(x) + return x + + +class ResidualConv1dGLU(nn.Cell): + """Residual dilated conv1d with gated activation units + + Args: + residual_channels (int): Residual input / output channels + gate_channels (int): Gated activation channels. + kernel_size (int): Kernel size + skip_out_channels (int): Skip connection channels. If None, it will set to the same as residual_channels. + cin_channels (int): Local conditioning channels. If given negative value, local conditioning is disabled. + gin_channels (int): Global conditioning channels. If given negative value, global conditioning is disabled. + dropout (float): Dropout rate. + padding (int): Padding for convolution layers. If None, padding value will be computed according to dilation + and kernel_size. + dilation (int): Dilation factor. 
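When `padding` is None it is derived from the kernel size and dilation, as in `__init__`; conceptually the causal case needs left-side padding only (the implementation pads symmetrically and then truncates the future time steps). A standalone sketch of that rule (hypothetical helper name):

```python
def default_padding(kernel_size, dilation, causal=True):
    # Causal: (kernel_size - 1) * dilation, so the output at time t
    # depends only on inputs up to t once future steps are trimmed.
    # Non-causal: symmetric padding keeping the receptive field centered.
    if causal:
        return (kernel_size - 1) * dilation
    return (kernel_size - 1) // 2 * dilation
```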
+ + """ + + def __init__(self, residual_channels=None, gate_channels=None, kernel_size=None, skip_out_channels=None, bias=True, + dropout=1 - 0.95, dilation=1, cin_channels=-1, gin_channels=-1, padding=None, causal=True): + super(ResidualConv1dGLU, self).__init__() + self.dropout = dropout + self.dropout_op = nn.Dropout(p=self.dropout) + self.eval_split_op = P.Split(axis=-1, output_num=2) + self.train_split_op = P.Split(axis=1, output_num=2) + self.tanh = P.Tanh() + self.sigmoid = P.Sigmoid() + self.mul = P.Mul() + self.add = P.Add() + + if skip_out_channels is None: + skip_out_channels = residual_channels + if padding is None: + if causal: + padding = (kernel_size - 1) * dilation + else: + padding = (kernel_size - 1) // 2 * dilation + self.causal = causal + + self.conv = Conv1d(residual_channels, gate_channels, kernel_size, pad_mode='pad', + padding=padding, dilation=dilation, has_bias=bias) + + # local conditioning + if cin_channels > 0: + self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, has_bias=False) + else: + self.conv1x1c = None + + # global conditioning + if gin_channels > 0: + self.conv1x1g = Conv1d(gin_channels, gate_channels, has_bias=False, kernel_size=1, dilation=1) + else: + self.conv1x1g = None + + gate_out_channels = gate_channels // 2 + self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, has_bias=bias) + self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, has_bias=bias) + self.factor = math.sqrt(0.5) + + def construct(self, x, c=None, g=None): + """ + + Args: + x(Tensor): One-hot audio signal, the shape is B x C x T + c(Tensor): local conditional feature, the shape is B x cin_channels x T + g(Tensor): global conditional feature, not used currently + + Returns: + Tensor: Output tensor + + """ + + residual = x + x = self.dropout_op(x) + x = self.conv(x) + # remove future time steps + x = x[:, :, :residual.shape[-1]] if self.causal else x + split_op = self.train_split_op + + a, b = split_op(x) + + # local 
conditioning + if c is not None: + c = _conv1x1_forward(self.conv1x1c, c, is_incremental=False) + ca, cb = split_op(c) + a, b = a + ca, b + cb + + # global conditioning + if g is not None: + g = _conv1x1_forward(self.conv1x1g, g, is_incremental=False) + ga, gb = split_op(g) + a, b = a + ga, b + gb + + x = self.mul(self.tanh(a), self.sigmoid(b)) + + # For skip connection + s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental=False) + + # For residual connection + x = _conv1x1_forward(self.conv1x1_out, x, is_incremental=False) + + x = self.add(x, residual) * self.factor + return x, s + + def sigmoid_numpy(self, x): + return 1. / (1 + np.exp(-x)) + + def incremental_forward(self, x, c=None, g=None, is_numpy=True): + """ + Incremental forward. Used for the inference stage. + + Args: + x (Tensor): One-hot audio signal, the shape is B x C x T + c (Tensor): local conditional feature, the shape is B x cin_channels x T + g (Tensor): global conditional feature, not used currently + + Returns: + ndarray + """ + residual = x + x = self.conv.incremental_forward(x, is_numpy=is_numpy) + if is_numpy: + a, b = np.split(x, indices_or_sections=2, axis=-1) + else: + a, b = self.eval_split_op(x) + + # local conditioning + if c is not None: + c = _conv1x1_forward(self.conv1x1c, c, is_incremental=True, is_numpy=is_numpy) + if is_numpy: + ca, cb = np.split(c, indices_or_sections=2, axis=-1) + else: + ca, cb = self.eval_split_op(c) + a, b = a + ca, b + cb + + # global conditioning + if g is not None: + g = _conv1x1_forward(self.conv1x1g, g, is_incremental=True, is_numpy=is_numpy) + if is_numpy: + ga, gb = np.split(g, indices_or_sections=2, axis=-1) + else: + ga, gb = self.eval_split_op(g) + a, b = a + ga, b + gb + + if is_numpy: + x = np.tanh(a) * self.sigmoid_numpy(b) + else: + x = self.mul(self.tanh(a), self.sigmoid(b)) + + # For skip connection + s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental=True, is_numpy=is_numpy) + + # For residual connection + x =
_conv1x1_forward(self.conv1x1_out, x, is_incremental=True, is_numpy=is_numpy) + + x = (x + residual) * self.factor + return x, s + + def clear_buffer(self): + """clear buffer""" + for c in [self.conv, self.conv1x1_out, self.conv1x1_skip, + self.conv1x1c, self.conv1x1g]: + if c is not None: + c.clear_buffer() diff --git a/HuaWeiExperiment/wavenet/wavenet_vocoder/tfcompat/__init__.py b/HuaWeiExperiment/wavenet/wavenet_vocoder/tfcompat/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/HuaWeiExperiment/wavenet/wavenet_vocoder/tfcompat/hparam.py b/HuaWeiExperiment/wavenet/wavenet_vocoder/tfcompat/hparam.py new file mode 100644 index 00000000..c428176b --- /dev/null +++ b/HuaWeiExperiment/wavenet/wavenet_vocoder/tfcompat/hparam.py @@ -0,0 +1,726 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hyperparameter values.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import numbers +import re + +import six + +## from tensorflow.contrib.training.python.training import hparam_pb2 +## from tensorflow.python.framework import ops +## from tensorflow.python.util import compat +## from tensorflow.python.util import deprecation + +# Define the regular expression for parsing a single clause of the input +# (delimited by commas). 
A legal clause looks like: +# <variable name>[<index>]? = <rhs> +# where <rhs> is either a single token or [] enclosed list of tokens. +# For example: "var[1] = a" or "x = [1,2,3]" +PARAM_RE = re.compile(r""" + (?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x" + (\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None + \s*=\s* + ((?P<val>[^,\[]*) # single value: "a" or None + | + \[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3" + ($|,\s*)""", re.VERBOSE) + + +def _parse_fail(name, var_type, value, values): + """Helper function for raising a value error for bad assignment.""" + raise ValueError( + 'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' % + (name, var_type.__name__, value, values)) + + +def _reuse_fail(name, values): + """Helper function for raising a value error for reuse of name.""" + raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name, + values)) + + +def _process_scalar_value(name, parse_fn, var_type, m_dict, values, + results_dictionary): + """Update results_dictionary with a scalar value. + + Used to update the results_dictionary to be returned by parse_values when + encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) + + Mutates results_dictionary. + + Args: + name: Name of variable in assignment ("s" or "arr"). + parse_fn: Function for parsing the actual value. + var_type: Type of named variable. + m_dict: Dictionary constructed from regex parsing. + m_dict['val']: RHS value (scalar) + m_dict['index']: List index value (or None) + values: Full expression being parsed + results_dictionary: The dictionary being updated for return by the parsing + function. + + Raises: + ValueError: If the name has already been used.
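For illustration, a clause such as `arr[0]=5` splits into name/index/value groups. A self-contained sketch restating the clause pattern with explicit named groups (mirrors the intent of PARAM_RE above):

```python
import re

# Clause grammar: <variable name>[<index>]? = <rhs>
CLAUSE_RE = re.compile(r"""
  (?P<name>[a-zA-Z][\w\.]*)      # variable name: "var" or "x"
  (\[\s*(?P<index>\d+)\s*\])?    # (optional) index: "1" or None
  \s*=\s*
  ((?P<val>[^,\[]*)              # single value: "a" or None
   |
   \[(?P<vals>[^\]]*)\])         # list of values: None or "1,2,3"
  ($|,\s*)""", re.VERBOSE)

m = CLAUSE_RE.match('arr[0]=5')
# (m.group('name'), m.group('index'), m.group('val')) -> ('arr', '0', '5')
```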
+ """ + try: + parsed_value = parse_fn(m_dict['val']) + except ValueError: + _parse_fail(name, var_type, m_dict['val'], values) + + # If no index is provided + if not m_dict['index']: + if name in results_dictionary: + _reuse_fail(name, values) + results_dictionary[name] = parsed_value + else: + if name in results_dictionary: + # The name has already been used as a scalar, then it + # will be in this dictionary and map to a non-dictionary. + if not isinstance(results_dictionary.get(name), dict): + _reuse_fail(name, values) + else: + results_dictionary[name] = {} + + index = int(m_dict['index']) + # Make sure the index position hasn't already been assigned a value. + if index in results_dictionary[name]: + _reuse_fail('{}[{}]'.format(name, index), values) + results_dictionary[name][index] = parsed_value + + +def _process_list_value(name, parse_fn, var_type, m_dict, values, + results_dictionary): + """Update results_dictionary from a list of values. + + Used to update results_dictionary to be returned by parse_values when + encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) + + Mutates results_dictionary. + + Args: + name: Name of variable in assignment ("arr"). + parse_fn: Function for parsing individual values. + var_type: Type of named variable. + m_dict: Dictionary constructed from regex parsing. + m_dict['val']: RHS value (scalar) + values: Full expression being parsed + results_dictionary: The dictionary being updated for return by the parsing + function. + + Raises: + ValueError: If the name has an index or the values cannot be parsed. 
+ """ + if m_dict['index'] is not None: + raise ValueError('Assignment of a list to a list index.') + elements = filter(None, re.split('[ ,]', m_dict['vals'])) + # Make sure the name hasn't already been assigned a value + if name in results_dictionary: + raise _reuse_fail(name, values) + try: + results_dictionary[name] = [parse_fn(e) for e in elements] + except ValueError: + _parse_fail(name, var_type, m_dict['vals'], values) + + +def _cast_to_type_if_compatible(name, param_type, value): + """Cast hparam to the provided type, if compatible. + + Args: + name: Name of the hparam to be cast. + param_type: The type of the hparam. + value: The value to be cast, if compatible. + + Returns: + The result of casting `value` to `param_type`. + + Raises: + ValueError: If the type of `value` is not compatible with param_type. + * If `param_type` is a string type, but `value` is not. + * If `param_type` is a boolean, but `value` is not, or vice versa. + * If `param_type` is an integer type, but `value` is not. + * If `param_type` is a float type, but `value` is not a numeric type. + """ + fail_msg = ( + "Could not cast hparam '%s' of type '%s' from value %r" % + (name, param_type, value)) + + # Some callers use None, for which we can't do any casting/checking. :( + if issubclass(param_type, type(None)): + return value + + # Avoid converting a non-string type to a string. + if (issubclass(param_type, (six.string_types, six.binary_type)) and + not isinstance(value, (six.string_types, six.binary_type))): + raise ValueError(fail_msg) + + # Avoid converting a number or string type to a boolean or vice versa. + if issubclass(param_type, bool) != isinstance(value, bool): + raise ValueError(fail_msg) + + # Avoid converting float to an integer (the reverse is fine). + if (issubclass(param_type, numbers.Integral) and + not isinstance(value, numbers.Integral)): + raise ValueError(fail_msg) + + # Avoid converting a non-numeric type to a numeric type. 
+ if (issubclass(param_type, numbers.Number) and + not isinstance(value, numbers.Number)): + raise ValueError(fail_msg) + + return param_type(value) + + +def parse_values(values, type_map): + """Parses hyperparameter values from a string into a python map. + + `values` is a string containing comma-separated `name=value` pairs. + For each pair, the value of the hyperparameter named `name` is set to + `value`. + + If a hyperparameter name appears multiple times in `values`, a ValueError + is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2'). + + If a hyperparameter name is used in both an index assignment and a scalar + assignment, a ValueError is raised (e.g. 'a=[1,2,3],a[0] = 1'). + + The hyperparameter name may contain '.' symbols, which will result in an + attribute name that is only accessible through the getattr and setattr + functions. (And must first be explicitly added through add_hparam.) + + WARNING: Use of '.' in your variable names is allowed, but is not well + supported and not recommended. + + The `value` in `name=value` must follow the syntax according to the + type of the parameter: + + * Scalar integer: A Python-parsable integer value. E.g.: 1, + 100, -12. + * Scalar float: A Python-parsable floating point value. E.g.: 1.0, + -.54e89. + * Boolean: Either true or false. + * Scalar string: A non-empty sequence of characters, excluding comma, + spaces, and square brackets. E.g.: foo, bar_1. + * List: A comma separated list of scalar values of the parameter type + enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low]. + + When index assignment is used, the corresponding type_map key should be the + list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not + "arr[1]"). + + Args: + values: String. Comma separated list of `name=value` pairs where + 'value' must follow the syntax described above. + type_map: A dictionary mapping hyperparameter names to types. Note every + parameter name in values must be a key in type_map.
The values must + conform to the types indicated, where a value V is said to conform to a + type T if either V has type T, or V is a list of elements of type T. + Hence, for a multidimensional parameter 'x' taking float values, + 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. + + Returns: + A python map mapping each name to either: + * A scalar value. + * A list of scalar values. + * A dictionary mapping index numbers to scalar values. + (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}") + + Raises: + ValueError: If there is a problem with input. + * If `values` cannot be parsed. + * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]'). + * If the same rvalue is assigned two different values (e.g. 'a=1,a=2', + 'a[1]=1,a[1]=2', or 'a=1,a=[1]') + """ + results_dictionary = {} + pos = 0 + while pos < len(values): + m = PARAM_RE.match(values, pos) + if not m: + raise ValueError('Malformed hyperparameter value: %s' % values[pos:]) + # Check that there is a comma between parameters and move past it. + pos = m.end() + # Parse the values. 
+ m_dict = m.groupdict() + name = m_dict['name'] + if name not in type_map: + raise ValueError('Unknown hyperparameter type for %s' % name) + type_ = type_map[name] + + # Set up correct parsing function (depending on whether type_ is a bool) + if type_ == bool: + + def parse_bool(value): + if value in ['true', 'True']: + return True + elif value in ['false', 'False']: + return False + else: + try: + return bool(int(value)) + except ValueError: + _parse_fail(name, type_, value, values) + + parse = parse_bool + else: + parse = type_ + + # If a single value is provided + if m_dict['val'] is not None: + _process_scalar_value(name, parse, type_, m_dict, values, + results_dictionary) + + # If the assigned value is a list: + elif m_dict['vals'] is not None: + _process_list_value(name, parse, type_, m_dict, values, + results_dictionary) + + else: # Not assigned a list or value + _parse_fail(name, type_, '', values) + + return results_dictionary + + +class HParams(object): + """Class to hold a set of hyperparameters as name-value pairs. + + A `HParams` object holds hyperparameters used to build and train a model, + such as the number of hidden units in a neural net layer or the learning rate + to use when training. + + You first create a `HParams` object by specifying the names and values of the + hyperparameters. + + To make them easily accessible the parameter names are added as direct + attributes of the class. A typical usage is as follows: + + ```python + # Create a HParams object specifying names and values of the model + # hyperparameters: + hparams = HParams(learning_rate=0.1, num_hidden_units=100) + + # The hyperparameters are available as attributes of the HParams object: + hparams.learning_rate ==> 0.1 + hparams.num_hidden_units ==> 100 + ``` + + Hyperparameters have a type, which is inferred from the type of their value + passed at construction time. The currently supported types are: integer, + float, boolean, string, and list of integer, float, boolean, or string.
+ + You can override hyperparameter values by calling the + [`parse()`](#HParams.parse) method, passing a string of comma separated + `name=value` pairs. This is intended to make it possible to override + any hyperparameter values from a single command-line flag to which + the user passes 'hyper-param=value' pairs. It avoids having to define + one flag for each hyperparameter. + + The syntax expected for each value depends on the type of the parameter. + See `parse()` for a description of the syntax. + + Example: + + ```python + # Define a command line flag to pass name=value pairs. + # For example using argparse: + import argparse + parser = argparse.ArgumentParser(description='Train my model.') + parser.add_argument('--hparams', type=str, + help='Comma separated list of "name=value" pairs.') + args = parser.parse_args() + ... + def my_program(): + # Create a HParams object specifying the names and values of the + # model hyperparameters: + hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100, + activations=['relu', 'tanh']) + + # Override hyperparameters values by parsing the command line + hparams.parse(args.hparams) + + # If the user passed `--hparams=learning_rate=0.3` on the command line + # then 'hparams' has the following attributes: + hparams.learning_rate ==> 0.3 + hparams.num_hidden_units ==> 100 + hparams.activations ==> ['relu', 'tanh'] + + # If the hyperparameters are in json format use parse_json: + hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}') + ``` + """ + + _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks. + + def __init__(self, hparam_def=None, model_structure=None, **kwargs): + """Create an instance of `HParams` from keyword arguments. + + The keyword arguments specify name-values pairs for the hyperparameters. + The parameter types are inferred from the type of the values passed. 
+ + The parameter names are added as attributes of `HParams` object, so they + can be accessed directly with the dot notation `hparams._name_`. + + Example: + + ```python + # Define 3 hyperparameters: 'learning_rate' is a float parameter, + # 'num_hidden_units' an integer parameter, and 'activation' a string + # parameter. + hparams = tf.HParams( + learning_rate=0.1, num_hidden_units=100, activation='relu') + + hparams.activation ==> 'relu' + ``` + + Note that a few names are reserved and cannot be used as hyperparameter + names. If you use one of the reserved name the constructor raises a + `ValueError`. + + Args: + hparam_def: Serialized hyperparameters, encoded as a hparam_pb2.HParamDef + protocol buffer. If provided, this object is initialized by + deserializing hparam_def. Otherwise **kwargs is used. + model_structure: An instance of ModelStructure, defining the feature + crosses to be used in the Trial. + **kwargs: Key-value pairs where the key is the hyperparameter name and + the value is the value for the parameter. + + Raises: + ValueError: If both `hparam_def` and initialization values are provided, + or if one of the arguments is invalid. + + """ + # Register the hyperparameters and their type in _hparam_types. + # This simplifies the implementation of parse(). + # _hparam_types maps the parameter name to a tuple (type, bool). + # The type value is the type of the parameter for scalar hyperparameters, + # or the type of the list elements for multidimensional hyperparameters. + # The bool value is True if the value is a list, False otherwise. 
+ self._hparam_types = {} + self._model_structure = model_structure + if hparam_def: +## self._init_from_proto(hparam_def) +## if kwargs: +## raise ValueError('hparam_def and initialization values are ' +## 'mutually exclusive') + raise ValueError('hparam_def has been disabled in this version') + else: + for name, value in six.iteritems(kwargs): + self.add_hparam(name, value) + +## def _init_from_proto(self, hparam_def): +## """Creates a new HParams from `HParamDef` protocol buffer. +## +## Args: +## hparam_def: `HParamDef` protocol buffer. +## """ +## assert isinstance(hparam_def, hparam_pb2.HParamDef) +## for name, value in hparam_def.hparam.items(): +## kind = value.WhichOneof('kind') +## if kind.endswith('_value'): +## # Single value. +## if kind.startswith('int64'): +## # Setting attribute value to be 'int' to ensure the type is compatible +## # with both Python2 and Python3. +## self.add_hparam(name, int(getattr(value, kind))) +## elif kind.startswith('bytes'): +## # Setting attribute value to be 'str' to ensure the type is compatible +## # with both Python2 and Python3. UTF-8 encoding is assumed. +## self.add_hparam(name, compat.as_str(getattr(value, kind))) +## else: +## self.add_hparam(name, getattr(value, kind)) +## else: +## # List of values. +## if kind.startswith('int64'): +## # Setting attribute value to be 'int' to ensure the type is compatible +## # with both Python2 and Python3. +## self.add_hparam(name, [int(v) for v in getattr(value, kind).value]) +## elif kind.startswith('bytes'): +## # Setting attribute value to be 'str' to ensure the type is compatible +## # with both Python2 and Python3. UTF-8 encoding is assumed. +## self.add_hparam( +## name, [compat.as_str(v) for v in getattr(value, kind).value]) +## else: +## self.add_hparam(name, [v for v in getattr(value, kind).value]) + + def add_hparam(self, name, value): + """Adds {name, value} pair to hyperparameters. + + Args: + name: Name of the hyperparameter. 
+ value: Value of the hyperparameter. Can be one of the following types: + int, float, string, int list, float list, or string list. + + Raises: + ValueError: if one of the arguments is invalid. + """ + # Keys in kwargs are unique, but 'name' could be the name of a pre-existing + # attribute of this object. In that case we refuse to use it as a + # hyperparameter name. + if getattr(self, name, None) is not None: + raise ValueError('Hyperparameter name is reserved: %s' % name) + if isinstance(value, (list, tuple)): + if not value: + raise ValueError( + 'Multi-valued hyperparameters cannot be empty: %s' % name) + self._hparam_types[name] = (type(value[0]), True) + else: + self._hparam_types[name] = (type(value), False) + setattr(self, name, value) + + def set_hparam(self, name, value): + """Set the value of an existing hyperparameter. + + This function verifies that the type of the value matches the type of the + existing hyperparameter. + + Args: + name: Name of the hyperparameter. + value: New value of the hyperparameter. + + Raises: + ValueError: If there is a type mismatch. + """ + param_type, is_list = self._hparam_types[name] + if isinstance(value, list): + if not is_list: + raise ValueError( + 'Must not pass a list for single-valued parameter: %s' % name) + setattr(self, name, [ + _cast_to_type_if_compatible(name, param_type, v) for v in value]) + else: + if is_list: + raise ValueError( + 'Must pass a list for multi-valued parameter: %s.' % name) + setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) + + def del_hparam(self, name): + """Removes the hyperparameter with key 'name'. + + Args: + name: Name of the hyperparameter. + """ + if hasattr(self, name): + delattr(self, name) + del self._hparam_types[name] + + def parse(self, values): + """Override hyperparameter values, parsing new values from a string. + + See parse_values for more detail on the allowed format for values. + + Args: + values: String.
Comma separated list of `name=value` pairs where + 'value' must follow the syntax described above. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values` cannot be parsed. + """ + type_map = dict() + for name, t in self._hparam_types.items(): + param_type, _ = t + type_map[name] = param_type + + values_map = parse_values(values, type_map) + return self.override_from_dict(values_map) + + def override_from_dict(self, values_dict): + """Override hyperparameter values, parsing new values from a dictionary. + + Args: + values_dict: Dictionary of name:value pairs. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values_dict` cannot be parsed. + """ + for name, value in values_dict.items(): + self.set_hparam(name, value) + return self + +## @deprecation.deprecated(None, 'Use `override_from_dict`.') + def set_from_map(self, values_map): + """DEPRECATED. Use override_from_dict.""" + return self.override_from_dict(values_dict=values_map) + + def set_model_structure(self, model_structure): + self._model_structure = model_structure + + def get_model_structure(self): + return self._model_structure + + def to_json(self, indent=None, separators=None, sort_keys=False): + """Serializes the hyperparameters into JSON. + + Args: + indent: If a non-negative integer, JSON array elements and object members + will be pretty-printed with that indent level. An indent level of 0, or + negative, will only insert newlines. `None` (the default) selects the + most compact representation. + separators: Optional `(item_separator, key_separator)` tuple. Default is + `(', ', ': ')`. + sort_keys: If `True`, the output dictionaries will be sorted by key. + + Returns: + A JSON string. + """ + return json.dumps( + self.values(), + indent=indent, + separators=separators, + sort_keys=sort_keys) + + def parse_json(self, values_json): + """Override hyperparameter values, parsing new values from a json object. 
+ + Args: + values_json: String containing a json object of name:value pairs. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values_json` cannot be parsed. + """ + values_map = json.loads(values_json) + return self.override_from_dict(values_map) + + def values(self): + """Return the hyperparameter values as a Python dictionary. + + Returns: + A dictionary with hyperparameter names as keys. The values are the + hyperparameter values. + """ + return {n: getattr(self, n) for n in self._hparam_types.keys()} + + def get(self, key, default=None): + """Returns the value of `key` if it exists, else `default`.""" + if key in self._hparam_types: + # Ensure that default is compatible with the parameter type. + if default is not None: + param_type, is_param_list = self._hparam_types[key] + type_str = 'list<%s>' % param_type if is_param_list else str(param_type) + fail_msg = ("Hparam '%s' of type '%s' is incompatible with " + 'default=%s' % (key, type_str, default)) + + is_default_list = isinstance(default, list) + if is_param_list != is_default_list: + raise ValueError(fail_msg) + + try: + if is_default_list: + for value in default: + _cast_to_type_if_compatible(key, param_type, value) + else: + _cast_to_type_if_compatible(key, param_type, default) + except ValueError as e: + raise ValueError('%s. %s' % (fail_msg, e)) + + return getattr(self, key) + + return default + + def __contains__(self, key): + return key in self._hparam_types + + def __str__(self): + return str(sorted(self.values().items())) + + def __repr__(self): + return '%s(%s)' % (type(self).__name__, self.__str__()) + + @staticmethod + def _get_kind_name(param_type, is_list): + """Returns the field name given parameter type and is_list. + + Args: + param_type: Data type of the hparam. + is_list: Whether this is a list. + + Returns: + A string representation of the field name. + + Raises: + ValueError: If parameter type is not recognized. 
+ """ + if issubclass(param_type, bool): + # This check must happen before issubclass(param_type, six.integer_types), + # since Python considers bool to be a subclass of int. + typename = 'bool' + elif issubclass(param_type, six.integer_types): + # Setting 'int' and 'long' types to be 'int64' to ensure the type is + # compatible with both Python2 and Python3. + typename = 'int64' + elif issubclass(param_type, (six.string_types, six.binary_type)): + # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is + # compatible with both Python2 and Python3. + typename = 'bytes' + elif issubclass(param_type, float): + typename = 'float' + else: + raise ValueError('Unsupported parameter type: %s' % str(param_type)) + + suffix = 'list' if is_list else 'value' + return '_'.join([typename, suffix]) + +## def to_proto(self, export_scope=None): # pylint: disable=unused-argument +## """Converts a `HParams` object to a `HParamDef` protocol buffer. +## +## Args: +## export_scope: Optional `string`. Name scope to remove. +## +## Returns: +## A `HParamDef` protocol buffer. +## """ +## hparam_proto = hparam_pb2.HParamDef() +## for name in self._hparam_types: +## # Parse the values. 
+## param_type, is_list = self._hparam_types.get(name, (None, None)) +## kind = HParams._get_kind_name(param_type, is_list) +## +## if is_list: +## if kind.startswith('bytes'): +## v_list = [compat.as_bytes(v) for v in getattr(self, name)] +## else: +## v_list = [v for v in getattr(self, name)] +## getattr(hparam_proto.hparam[name], kind).value.extend(v_list) +## else: +## v = getattr(self, name) +## if kind.startswith('bytes'): +## v = compat.as_bytes(getattr(self, name)) +## setattr(hparam_proto.hparam[name], kind, v) +## +## return hparam_proto + +## @staticmethod +## def from_proto(hparam_def, import_scope=None): # pylint: disable=unused-argument +## return HParams(hparam_def=hparam_def) + + +## ops.register_proto_function( +## 'hparams', +## proto_type=hparam_pb2.HParamDef, +## to_proto=HParams.to_proto, +## from_proto=HParams.from_proto) diff --git a/HuaWeiExperiment/wavenet/wavenet_vocoder/tfcompat/readme.md b/HuaWeiExperiment/wavenet/wavenet_vocoder/tfcompat/readme.md new file mode 100644 index 00000000..3d94e4c4 --- /dev/null +++ b/HuaWeiExperiment/wavenet/wavenet_vocoder/tfcompat/readme.md @@ -0,0 +1,8 @@ +Source: hparam.py copied from tensorflow v1.12.0. + +https://github.com/tensorflow/tensorflow/blob/v1.12.0/tensorflow/contrib/training/python/training/hparam.py + +downloaded with the following command: +wget https://github.com/tensorflow/tensorflow/raw/v1.12.0/tensorflow/contrib/training/python/training/hparam.py + +After removing all other TensorFlow dependencies from this file, the class keeps serving its original purpose. The functions dropped in that process are not used in this project.
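The comma-separated `name=value` grammar that the vendored `parse_values` documents can be illustrated with a minimal, self-contained sketch. Note this is not the module's code: `PAIR_RE` below is a hypothetical, simplified stand-in for the real `PARAM_RE` (it ignores index assignments like `arr[1]=0`), and only the scalar and list cases are handled.

```python
import re

# Simplified stand-in for PARAM_RE: matches `name=scalar` or `name=[v1,v2,...]`.
PAIR_RE = re.compile(r'(?P<name>\w+)=(\[(?P<vals>[^\]]*)\]|(?P<val>[^,\[]*))($|,)')


def parse_values(values, type_map):
    """Parse 'a=1,b=[2,3]'-style strings into a dict, casting via type_map."""
    results = {}
    pos = 0
    while pos < len(values):
        m = PAIR_RE.match(values, pos)
        if not m:
            raise ValueError('Malformed hyperparameter value: %s' % values[pos:])
        pos = m.end()
        name = m.group('name')
        if name not in type_map:
            raise ValueError('Unknown hyperparameter: %s' % name)
        if name in results:
            # Same duplicate-assignment rule as the docstring describes.
            raise ValueError('Duplicate assignment of %s' % name)
        cast = type_map[name]
        if m.group('vals') is not None:            # list assignment, e.g. a=[1,2]
            results[name] = [cast(v) for v in m.group('vals').split(',')]
        else:                                      # scalar assignment, e.g. a=1
            results[name] = cast(m.group('val'))
    return results


print(parse_values('lr=0.3,units=100,acts=[1,2]',
                   {'lr': float, 'units': int, 'acts': int}))
# {'lr': 0.3, 'units': 100, 'acts': [1, 2]}
```

As in the real implementation, element types come from `type_map`, not from the string, so `acts=[1,2]` yields a list of ints here.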
diff --git a/HuaWeiExperiment/wavenet/wavenet_vocoder/upsample.py b/HuaWeiExperiment/wavenet/wavenet_vocoder/upsample.py new file mode 100644 index 00000000..32b4ba15 --- /dev/null +++ b/HuaWeiExperiment/wavenet/wavenet_vocoder/upsample.py @@ -0,0 +1,111 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Upsampling. +""" +from __future__ import with_statement, print_function, absolute_import +import numpy as np +from mindspore import nn +from mindspore.ops import operations as P + + +class Resize(nn.Cell): + """ + Resize input Tensor + """ + + def __init__(self, x_scale, y_scale, mode="nearest"): + super(Resize, self).__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def construct(self, x): + _, _, h, w = x.shape + interpolate_op = P.ResizeNearestNeighbor((self.y_scale * h, self.x_scale * w)) + return interpolate_op(x) + + +def _get_activation(upsample_activation): + """get activation""" + nonlinear = getattr(nn, upsample_activation) + return nonlinear + + +class UpsampleNetwork(nn.Cell): + """UpsampleNetwork""" + def __init__(self, upsample_scales, mode="nearest", + freq_axis_kernel_size=1, cin_pad=0, cin_channels=80): + super(UpsampleNetwork, self).__init__() + self.expand_op = P.ExpandDims() + self.squeeze_op = P.Squeeze(1) + up_layers = [] + total_scale = np.prod(upsample_scales) + self.indent = 
cin_pad * total_scale + for scale in upsample_scales: + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + k_size = (freq_axis_kernel_size, scale * 2 + 1) + padding = (freq_axis_padding, freq_axis_padding, scale, scale) + stretch = Resize(scale, 1, mode) + conv = nn.Conv2d(1, 1, kernel_size=k_size, has_bias=False, pad_mode='pad', padding=padding) + up_layers.append(stretch) + up_layers.append(conv) + self.up_layers = nn.CellList(up_layers) + + def construct(self, c): + """ + + Args: + c (Tensor): Local conditioning feature + + Returns: + Tensor: Upsampling feature + + """ + # B x 1 x C x T + c = self.expand_op(c, 1) + for f in self.up_layers: + c = f(c) + # B x C x T + c = self.squeeze_op(c) + + return c + + +class ConvInUpsampleNetwork(nn.Cell): + """Upsample Network + + Args: + upsample_scales (list): Upsample_scales list. + upsample_activation (str): Upsample_activation. + mode (str): Resize mode, default is NearestNeighbor. + cin_channels (int): Local conditioning channels. + freq_axis_kernel_size (int): Freq-axis kernel_size for the convolution layers after resize. 
+ + """ + + def __init__(self, upsample_scales, mode="nearest", + freq_axis_kernel_size=1, cin_pad=0, + cin_channels=80): + super(ConvInUpsampleNetwork, self).__init__() + ks = 2 * cin_pad + 1 + self.conv_in = nn.Conv1d(cin_channels, cin_channels, kernel_size=ks, has_bias=False, pad_mode='pad', padding=0) + self.upsample = UpsampleNetwork(upsample_scales, mode, freq_axis_kernel_size, cin_pad=0, + cin_channels=cin_channels) + + def construct(self, c): + c = self.conv_in(c) + c_up = self.upsample(c) + return c_up diff --git a/HuaWeiExperiment/wavenet/wavenet_vocoder/util.py b/HuaWeiExperiment/wavenet/wavenet_vocoder/util.py new file mode 100644 index 00000000..4fea1d98 --- /dev/null +++ b/HuaWeiExperiment/wavenet/wavenet_vocoder/util.py @@ -0,0 +1,25 @@ +# coding: utf-8 +from __future__ import with_statement, print_function, absolute_import + + +def _assert_valid_input_type(s): + assert s == "mulaw-quantize" or s == "mulaw" or s == "raw" + + +def is_mulaw_quantize(s): + _assert_valid_input_type(s) + return s == "mulaw-quantize" + + +def is_mulaw(s): + _assert_valid_input_type(s) + return s == "mulaw" + + +def is_raw(s): + _assert_valid_input_type(s) + return s == "raw" + + +def is_scalar_input(s): + return is_raw(s) or is_mulaw(s) diff --git a/HuaWeiExperiment/wavenet/wavenet_vocoder/wavenet.py b/HuaWeiExperiment/wavenet/wavenet_vocoder/wavenet.py new file mode 100644 index 00000000..50ca5b03 --- /dev/null +++ b/HuaWeiExperiment/wavenet/wavenet_vocoder/wavenet.py @@ -0,0 +1,335 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +WaveNet construction. +""" +from __future__ import with_statement, print_function, absolute_import + +import math +import numpy as np + +from mindspore import nn, Tensor +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from wavenet_vocoder import upsample +from .modules import Embedding +from .modules import Conv1d1x1 +from .modules import ResidualConv1dGLU +from .mixture import sample_from_discretized_mix_logistic +from .mixture import sample_from_mix_gaussian +from .mixture import sample_from_mix_onehotcategorical + + +class WaveNet(nn.Cell): + """ + WaveNet model definition. Only local conditioning is supported. + + Args: + out_channels (int): Output channels. If input_type is a mu-law quantized one-hot vector, it should equal the + number of quantize channels; otherwise it equals num_mixtures x 3. Default: 256. + layers (int): Number of ResidualConv1dGLU layers. + stacks (int): Number of dilation cycles. + residual_channels (int): Residual input / output channels. + gate_channels (int): Gated activation channels. + skip_out_channels (int): Skip connection channels. + kernel_size (int): Kernel size. + dropout (float): Dropout rate. + cin_channels (int): Local conditioning channels. If a negative value is given, local conditioning is disabled. + gin_channels (int): Global conditioning channels. If a negative value is given, global conditioning is disabled. + n_speakers (int): Number of speakers. This is used when global conditioning is enabled. + upsample_conditional_features (bool): Whether to upsample local conditioning features by resize_nearestneighbor + and convolution or not. + scalar_input (Bool): If True, scalar input ([-1, 1]) is expected; otherwise, a quantized one-hot vector + is expected. + use_speaker_embedding (Bool): Whether to use speaker embedding or not.
+ + """ + + def __init__(self, out_channels=256, layers=20, stacks=2, + residual_channels=512, + gate_channels=512, + skip_out_channels=512, + kernel_size=3, dropout=1 - 0.95, + cin_channels=-1, gin_channels=-1, n_speakers=None, + upsample_conditional_features=False, + upsample_net="ConvInUpsampleNetwork", + upsample_params=None, + scalar_input=False, + use_speaker_embedding=False, + output_distribution="Logistic", + cin_pad=0, + ): + super(WaveNet, self).__init__() + self.transpose_op = P.Transpose() + self.softmax = P.Softmax(axis=1) + self.reshape_op = P.Reshape() + self.zeros_op = P.Zeros() + self.ones_op = P.Ones() + self.squeeze_op = P.Squeeze() + self.expandim_op = P.ExpandDims() + self.transpose_op = P.Transpose() + self.tile_op = P.Tile() + self.scalar_input = scalar_input + self.out_channels = out_channels + self.cin_channels = cin_channels + self.output_distribution = output_distribution + self.fack_data = P.Zeros() + assert layers % stacks == 0 + layers_per_stack = layers // stacks # 24 / 4 = 6 + if scalar_input: + self.first_conv = Conv1d1x1(1, residual_channels) + else: + self.first_conv = Conv1d1x1(out_channels, residual_channels) + + conv_layers = [] + for layer in range(layers): + dilation = 2 ** (layer % layers_per_stack) # 1, 2, 4, 8, 16, 32 + conv = ResidualConv1dGLU( + residual_channels, gate_channels, + kernel_size=kernel_size, + skip_out_channels=skip_out_channels, + bias=True, + dropout=dropout, + dilation=dilation, + cin_channels=cin_channels, + gin_channels=gin_channels) + conv_layers.append(conv) + self.conv_layers = nn.CellList(conv_layers) + self.last_conv_layers = nn.CellList([ + nn.ReLU(), + Conv1d1x1(skip_out_channels, skip_out_channels), + nn.ReLU(), + Conv1d1x1(skip_out_channels, out_channels)]) + + if gin_channels > 0 and use_speaker_embedding: + assert n_speakers is not None + self.embed_speakers = Embedding( + n_speakers, gin_channels, padding_idx=None, std=0.1) + else: + self.embed_speakers = None + + if 
upsample_conditional_features: + self.upsample_net = getattr(upsample, upsample_net)(**upsample_params) + else: + self.upsample_net = None + + self.factor = math.sqrt(1.0 / len(self.conv_layers)) # sqrt( 1 / 24) + + def _expand_global_features(self, batch_size, time_step, g_fp, is_expand=True): + """Expand global conditioning features to all time steps + + Args: + batch_size (int): Batch size. + time_step (int): Time length. + g_fp (Tensor): Global features, (B x C) or (B x C x 1). + is_expand (bool) : Expanded global conditioning features + + Returns: + Tensor: B x C x T or B x T x C or None + """ + if g_fp is None: + return None + if len(g_fp.shape) == 2: + g_fp = self.expandim_op(g_fp, -1) + else: + g_fp = g_fp + + if is_expand: + expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step)) + return expand_fp + expand_fp = self.tile_op(g_fp, (batch_size, 1, time_step)) + expand_fp = self.transpose_op(expand_fp, (0, 2, 1)) + return expand_fp + + def construct(self, x, cond=None, g=None, softmax=False): + """ + + Args: + x (Tensor): One-hot encoded audio signal + c (Tensor): Local conditioning feature + g (Tensor): Global conditioning feature + softmax (bool): Whether use softmax or not + + Returns: + Tensor: Net output + + """ + + g = None + B, _, T = x.shape + if g is not None: + if self.embed_speakers is not None: + g = self.embed_speakers(self.reshape_op(g, (B, -1))) + g = self.transpose_op(g, (0, 2, 1)) + g_bct = self._expand_global_features(B, T, g, is_expand=True) # None + + if cond is not None and self.upsample_net is not None: + cond = self.upsample_net(cond) # [B, 128, 10240] + + x = self.first_conv(x) + skips = 0 + for f in self.conv_layers: + x, hidden = f(x, cond, g_bct) # x=[B, 128, 10240], hidden=[B, 128, 10240] + skips += hidden + skips *= self.factor + + x = skips # x=[B, 128, 10240] + for f in self.last_conv_layers: + x = f(x) # x=[B, 2, 10240] + x = self.softmax(x) if softmax else x + + return x + + def relu_numpy(self, inX): + """numpy relu 
function""" + return np.maximum(0, inX) + + def softmax_numpy(self, x): + """ numpy softmax function """ + x -= np.max(x, axis=1, keepdims=True) + return np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True) + + def incremental_forward(self, initial_input=None, c_=None, g=None, + T=100, test_inputs=None, + tqdm=lambda x: x, softmax=True, quantize=True, + log_scale_min=-50.0, is_numpy=True): + """ + Incremental forward. The current output depends on the last output. + + Args: + initial_input (Tensor): Initial input, the shape is B x C x 1 + c_ (Tensor): Local conditioning feature, the shape is B x C x T + g (Tensor): Global conditioning feature, the shape is B x C or B x C x 1 + T (int): Number of decoding time steps. + test_inputs: Teacher forcing inputs (for debugging) + tqdm (lambda): Progress wrapper such as tqdm + softmax (bool): Whether to use softmax or not + quantize (bool): Whether to quantize the softmax output of the last step when decoding the current step + log_scale_min (float): Log scale minimum value + + Returns: + Tensor: Predicted one-hot encoded samples or scalar vector depending on loss type + + """ + + self.clear_buffer() + B = 1 + + if test_inputs is not None: + if self.scalar_input: + if test_inputs.shape[1] == 1: + test_inputs = self.transpose_op(test_inputs, (0, 2, 1)) + else: + if test_inputs.shape[1] == self.out_channels: + test_inputs = self.transpose_op(test_inputs, (0, 2, 1)) + + B = test_inputs.shape[0] + if T is None: + T = test_inputs.shape[1] + else: + T = max(T, test_inputs.shape[1]) + T = int(T) + + # Global conditioning + if g is not None: + if self.embed_speakers is not None: + g = self.embed_speakers(self.reshape_op(g, (B, -1))) + g = self.transpose_op(g, (0, 2, 1)) + assert g.dim() == 3 + g_btc = self._expand_global_features(B, T, g, is_expand=False) + + # Local conditioning + if c_ is not None: + B = c_.shape[0] + if self.upsample_net is not None: + c_ = self.upsample_net(c_) + assert c_.shape[-1] == T + if c_.shape[-1] == T: + c_ = self.transpose_op(c_, (0, 2, 1)) + + outputs = [] + if
initial_input is None: + if self.scalar_input: + initial_input = self.zeros_op((B, 1, 1), mstype.float32) + else: + initial_input = np.zeros((B, 1, self.out_channels), np.float32) + initial_input[:, :, 127] = 1 + initial_input = Tensor(initial_input) + else: + if initial_input.shape[1] == self.out_channels: + initial_input = self.transpose_op(initial_input, (0, 2, 1)) + + current_input = initial_input.asnumpy() + + for t in tqdm(range(T)): + if test_inputs is not None and t < test_inputs.shape[1]: + current_input = self.expandim_op(test_inputs[:, t, :], 1) + else: + if t > 0: + current_input = outputs[-1] + + # Conditioning features for single time step + ct = None if c_ is None else self.expandim_op(c_[:, t, :], 1) + gt = None if g is None else self.expandim_op(g_btc[:, t, :], 1) + + x = current_input + ct = ct.asnumpy() + x = self.first_conv.incremental_forward(x) + + skips = 0 + for f in self.conv_layers: + x, h = f.incremental_forward(x, ct, gt) + skips += h + skips *= self.factor + x = skips + + for f in self.last_conv_layers: + try: + x = f.incremental_forward(x) + except AttributeError: + x = self.relu_numpy(x) + + # Generate next input by sampling + if self.scalar_input: + if self.output_distribution == "Logistic": + x = sample_from_discretized_mix_logistic(x.reshape((B, -1, 1)), log_scale_min=log_scale_min) + + elif self.output_distribution == "Normal": + x = sample_from_mix_gaussian(x.reshape((B, -1, 1)), log_scale_min=log_scale_min) + else: + assert False + else: + x = self.softmax_numpy(np.reshape(x, (B, -1))) if softmax else np.reshape(x, (B, -1)) + if quantize: + x = sample_from_mix_onehotcategorical(x) + + outputs += [x] + # T x B x C + outputs = np.stack(outputs, 0) + # B x C x T + outputs = np.transpose(outputs, (1, 2, 0)) + self.clear_buffer() + return outputs + + def clear_buffer(self): + """clear buffer""" + self.first_conv.clear_buffer() + for f in self.conv_layers: + f.clear_buffer() + for f in self.last_conv_layers: + try: + f.clear_buffer() 
+ except AttributeError: + pass diff --git "a/HuaWeiExperiment/\350\257\255\351\237\263\350\257\206\345\210\253-\345\215\216\344\270\272\345\256\236\351\252\214-wavenet.docx" "b/HuaWeiExperiment/\350\257\255\351\237\263\350\257\206\345\210\253-\345\215\216\344\270\272\345\256\236\351\252\214-wavenet.docx" new file mode 100644 index 00000000..4f0b5a1a Binary files /dev/null and "b/HuaWeiExperiment/\350\257\255\351\237\263\350\257\206\345\210\253-\345\215\216\344\270\272\345\256\236\351\252\214-wavenet.docx" differ diff --git a/README.md b/README.md index 7cd82328..ef754c3b 100644 --- a/README.md +++ b/README.md @@ -1,210 +1,119 @@ -# talkingface-toolkit -## Framework overview -### checkpoints -Stores the additional pretrained models needed to train and evaluate models; see the [README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/checkpoints/README.md) in that folder for details. - -### dataset -Stores the datasets and the data produced by preprocessing them; see the dataset [README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/dataset/README.md) for details. - -### saved -Stores model checkpoints saved during training; created automatically when a model is saved during training. - -### talkingface -The main functional module, containing all core code. - -#### config -Automatically generates all configuration related to the model, dataset, training, and evaluation from the model and dataset names. -``` -config/ - -├── configurator.py - -``` -#### data -- dataprocess: model-specific data-processing code (e.g., the audio feature extraction or inference-time processing implemented in the upstream repository). If the model you implement needs this, create a corresponding file. -- dataset: every model must subclass `torch.utils.data.Dataset` to load data, and every model must have a `model_name+'_dataset.py'` file.
The return value of `__getitem__()` should be packed into a dict. (core part) -``` -data/ - -├── dataprocess - -| ├── wav2lip_process.py - -| ├── xxxx_process.py - -├── dataset - -| ├── wav2lip_dataset.py - -| ├── xxx_dataset.py -``` - -#### evaluate -Code related to model evaluation. -The LSE metric takes a list of generated videos. -The SSIM metric takes lists of generated videos and ground-truth videos. - -#### model -The networks of the implemented models and their methods (core part). - -Three main categories: -- audio-driven -- image-driven -- nerf-based (neural radiance field methods) - -``` -model/ - -├── audio_driven_talkingface - -| ├── wav2lip.py - -├── image_driven_talkingface - -| ├── xxxx.py - -├── nerf_based_talkingface - -| ├── xxxx.py - -├── abstract_talkingface.py - -``` - -#### properties -Stores the default configuration files, including: -- dataset configuration files -- model configuration files -- the general configuration file - -Add configuration files for your model and dataset as needed; the general configuration file `overall.yaml` is usually left unchanged. -``` -properties/ - -├── dataset - -| ├── xxx.yaml - -├── model - -| ├── xxx.yaml - -├── overall.yaml - -``` - -#### quick_start -The common entry point; it configures the dataset and model from the passed arguments, then trains and evaluates (usually no changes needed). -``` -quick_start/ - -├── quick_start.py - -``` - -#### trainer -The main classes for training and evaluation. If the base class `Trainer` can cover everything your model needs, no new trainer is required. If training has model-specific parts, subclass `Trainer`; the parts to override will usually be `_train_epoch()` and `_valid_epoch()`. The subclass should be named `{model_name}Trainer`. -``` -trainer/ - -├── trainer.py - -``` - -#### utils -Shared utilities, including `s3fd` face detection and methods for extracting frames and audio from video, plus helpers that look up the model and data classes from the configuration. -Usually left unchanged, though necessary and reasonably general data-processing files may be added. - -## Usage -### Environment requirements -- `python=3.8` -- `torch==1.13.1+cu116` (GPU build; a CPU build works if the device does not support CUDA) -- `numpy==1.20.3` -- `librosa==0.10.1` - -Try to match these package versions exactly. - -Two ways to install the remaining dependencies: -``` -pip install -r requirements.txt - -or - -conda env create -f environment.yml -``` - -A conda virtual environment is strongly recommended!
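The trainer-override convention described above (a subclass named `{model_name}Trainer` that overrides only `_train_epoch()` and `_valid_epoch()`) can be sketched as follows. This is a minimal stand-in, not the toolkit's real `Trainer` class: the base class, the placeholder "losses", and `Wav2lipTrainer`'s step logic are all illustrative.

```python
class Trainer:
    """Illustrative stand-in base trainer: loops over epochs, delegating to hooks."""

    def __init__(self, model, epochs=2):
        self.model = model
        self.epochs = epochs

    def fit(self, train_data, valid_data):
        # The shared loop lives in the base class; subclasses only swap the hooks.
        history = []
        for _ in range(self.epochs):
            history.append((self._train_epoch(train_data),
                            self._valid_epoch(valid_data)))
        return history

    def _train_epoch(self, data):
        return sum(data) / len(data)  # placeholder "loss"

    def _valid_epoch(self, data):
        return sum(data) / len(data)


class Wav2lipTrainer(Trainer):
    """Named {model_name}Trainer per the convention; overrides only the two hooks."""

    def _train_epoch(self, data):
        return min(data)  # a model-specific training step would go here

    def _valid_epoch(self, data):
        return max(data)


trainer = Wav2lipTrainer(model=None, epochs=1)
print(trainer.fit([3, 1, 2], [5, 4]))  # → [(1, 5)]
```

The point of the convention is that `quick_start` can drive any model through the same `fit()` loop while each model customizes only its per-epoch behavior.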
- -### 训练和评估 - -```bash -python run_talkingface.py --model=xxxx --dataset=xxxx (--other_parameters=xxxxxx) -``` - -### 权重文件 - -- LSE评估需要的权重: syncnet_v2.model [百度网盘下载](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc) -- wav2lip需要的lip expert 权重:lipsync_expert.pth [百度网下载](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc) - -## 可选论文: -### Aduio_driven talkingface -| 模型简称 | 论文 | 代码仓库 | -|:--------:|:--------:|:--------:| -| MakeItTalk | [paper](https://arxiv.org/abs/2004.12992) | [code](https://github.com/yzhou359/MakeItTalk) | -| MEAD | [paper](https://wywu.github.io/projects/MEAD/support/MEAD.pdf) | [code](https://github.com/uniBruce/Mead) | -| RhythmicHead | [paper](https://arxiv.org/pdf/2007.08547v1.pdf) | [code](https://github.com/lelechen63/Talking-head-Generation-with-Rhythmic-Head-Motion) | -| PC-AVS | [paper](https://arxiv.org/abs/2104.11116) | [code](https://github.com/Hangz-nju-cuhk/Talking-Face_PC-AVS) | -| EVP | [paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Ji_Audio-Driven_Emotional_Video_Portraits_CVPR_2021_paper.pdf) | [code](https://github.com/jixinya/EVP) | -| LSP | [paper](https://arxiv.org/abs/2109.10595) | [code](https://github.com/YuanxunLu/LiveSpeechPortraits) | -| EAMM | [paper](https://arxiv.org/pdf/2205.15278.pdf) | [code](https://github.com/jixinya/EAMM/) | -| DiffTalk | [paper](https://arxiv.org/abs/2301.03786) | [code](https://github.com/sstzal/DiffTalk) | -| TalkLip | [paper](https://arxiv.org/pdf/2303.17480.pdf) | [code](https://github.com/Sxjdwang/TalkLip) | -| EmoGen | [paper](https://arxiv.org/pdf/2303.11548.pdf) | [code](https://github.com/sahilg06/EmoGen) | -| SadTalker | [paper](https://arxiv.org/abs/2211.12194) | [code](https://github.com/OpenTalker/SadTalker) | -| HyperLips | [paper](https://arxiv.org/abs/2310.05720) | [code](https://github.com/semchan/HyperLips) | -| PHADTF | [paper](http://arxiv.org/abs/2002.10137) | [code](https://github.com/yiranran/Audio-driven-TalkingFace-HeadPose) | -| 
VideoReTalking | [paper](https://arxiv.org/abs/2211.14758) | [code](https://github.com/OpenTalker/video-retalking#videoretalking--audio-based-lip-synchronization-for-talking-head-video-editing-in-the-wild-) -| | - - - -### Image_driven talkingface -| 模型简称 | 论文 | 代码仓库 | -|:--------:|:--------:|:--------:| -| PIRenderer | [paper](https://arxiv.org/pdf/2109.08379.pdf) | [code](https://github.com/RenYurui/PIRender) | -| StyleHEAT | [paper](https://arxiv.org/pdf/2203.04036.pdf) | [code](https://github.com/OpenTalker/StyleHEAT) | -| MetaPortrait | [paper](https://arxiv.org/abs/2212.08062) | [code](https://github.com/Meta-Portrait/MetaPortrait) | -| | -### Nerf-based talkingface -| 模型简称 | 论文 | 代码仓库 | -|:--------:|:--------:|:--------:| -| AD-NeRF | [paper](https://arxiv.org/abs/2103.11078) | [code](https://github.com/YudongGuo/AD-NeRF) | -| GeneFace | [paper](https://arxiv.org/abs/2301.13430) | [code](https://github.com/yerfor/GeneFace) | -| DFRF | [paper](https://arxiv.org/abs/2207.11770) | [code](https://github.com/sstzal/DFRF) | -| | -### text_to_speech -| 模型简称 | 论文 | 代码仓库 | -|:--------:|:--------:|:--------:| -| VITS | [paper](https://arxiv.org/abs/2106.06103) | [code](https://github.com/jaywalnut310/vits) | -| Glow TTS | [paper](https://arxiv.org/abs/2005.11129) | [code](https://github.com/jaywalnut310/glow-tts) | -| FastSpeech2 | [paper](https://arxiv.org/abs/2006.04558v1) | [code](https://github.com/ming024/FastSpeech2) | -| StyleTTS2 | [paper](https://arxiv.org/abs/2306.07691) | [code](https://github.com/yl4579/StyleTTS2) | -| Grad-TTS | [paper](https://arxiv.org/abs/2105.06337) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS) | -| FastSpeech | [paper](https://arxiv.org/abs/1905.09263) | [code](https://github.com/xcmyz/FastSpeech) | -| | -### voice_conversion -| 模型简称 | 论文 | 代码仓库 | -|:--------:|:--------:|:--------:| -| StarGAN-VC | [paper](http://www.kecl.ntt.co.jp/people/kameoka.hirokazu/Demos/stargan-vc2/index.html) | 
[code](https://github.com/kamepong/StarGAN-VC) | -| Emo-StarGAN | [paper](https://www.researchgate.net/publication/373161292_Emo-StarGAN_A_Semi-Supervised_Any-to-Many_Non-Parallel_Emotion-Preserving_Voice_Conversion) | [code](https://github.com/suhitaghosh10/emo-stargan) | -| adaptive-VC | [paper](https://arxiv.org/abs/1904.05742) | [code](https://github.com/jjery2243542/adaptive_voice_conversion) | -| DiffVC | [paper](https://arxiv.org/abs/2109.13821) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC) | -| Assem-VC | [paper](https://arxiv.org/abs/2104.00931) | [code](https://github.com/maum-ai/assem-vc) | -| | - -## Assignment requirements -- Make sure training and evaluation can be run by entering only the model and dataset names on the command line. (For repositories that provide no training code, training may be skipped.) -- Each group must submit a README describing the implemented features, the final training and evaluation screenshots, the dependencies used, the division of work among members, etc. - - +Paper reproduction: LiveSpeechPortraits + +**For details, please see the document "语音识别-论文复现-LiveSpeechPortraits.docx" and "演示视频.mkv".** + +Selected paper: LiveSpeechPortraits + +Implemented functionality: given an input audio file (the audio path is specified in /properties/model/live_speech_portraits.yaml), generate a video from a pretrained model (each model is specific to one person) in which that person speaks with lip motion matching the audio. + +Training: none, because the paper's authors did not release training code or the corresponding dataset. According to the author, the paper was completed during an internship, so the copyright of the training code and related materials belongs to the company and cannot be shared; the trained models and partial code were released for verification instead. + +The two screenshots below show why training is missing: + +![缺失训练的原因](./缺失训练的原因.png) + +![缺失训练的原因-2](./缺失训练的原因-2.png) + +Run: + +Enter the following command on the command line: + +python run_talkingface.py --model=live_speech_portraits --dataset=live_speech_portraits --evaluate_model_file=notNone + +A command-line screenshot is shown below; for other screenshots, please see **the document "语音识别-论文复现-LiveSpeechPortraits.docx" and "演示视频.mkv"**: + +![命令行截图](./命令行截图.png) + +The screenshot below shows evaluation; for details, see **the document "语音识别-论文复现-LiveSpeechPortraits.docx" and "演示视频.mkv"**. + +Evaluation screenshot: ![验证截图](./验证截图.png) + +Evaluation notes: to evaluate other audio clips or other speakers, edit live_speech_portraits.yaml under talkingface/properties/model; apart from the APC ckp_path under model_params, which needs no change, update every other path to the audio and speaker you want to use. + +Note: please download the models and data required for evaluation from [data - Google 云端硬盘](https://drive.google.com/drive/folders/1sHc2xEEGwnb0h2rkUhG9sPmOxvRvPVpJ) and place them in the checkpoints/live_speech_portraits directory. + +Dependencies used: + +absl-py==2.1.0
+aiosignal==1.3.1 +albumentations==0.5.2 +anyio==4.2.0 +attrs==23.2.0 +cachetools==5.3.2 +click==8.1.7 +cog==0.9.4 +colorlog==6.7.0 +dominate==2.9.1 +exceptiongroup==1.2.0 +fastapi==0.98.0 +filelock==3.13.1 +fonttools==4.25.0 +frozenlist==1.4.1 +google-auth==2.26.2 +google-auth-oauthlib==0.4.6 +grpcio==1.60.0 +h11==0.14.0 +h5py==3.10.0 +httptools==0.6.1 +imageio==2.33.1 +imgaug==0.4.0 +Jinja2==3.1.3 +jsonschema==4.21.0 +jsonschema-specifications==2023.12.1 +librosa==0.7.0 +llvmlite==0.31.0 +Markdown==3.5.2 +MarkupSafe==2.1.4 +mkl-service==2.4.0 +mpmath==1.3.0 +munkres==1.1.4 +networkx==3.1 +numba==0.48.0 +numpy==1.20.3 +oauthlib==3.2.2 +opencv-python==3.4.9.33 +opencv-python-headless==4.9.0.80 +pandas==1.3.4 +pkgutil_resolve_name==1.3.10 +protobuf==3.19.0 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pydantic==1.10.14 +python-dotenv==1.0.1 +python-speech-features==0.6 +pytz==2023.3.post1 +PyWavelets==1.4.1 +PyYAML==6.0.1 +ray==2.6.3 +referencing==0.32.1 +requests-oauthlib==1.3.1 +resampy==0.3.1 +rpds-py==0.17.1 +rsa==4.9 +scikit-image==0.16.2 +scipy==1.10.1 +shapely==2.0.2 +sniffio==1.3.0 +starlette==0.27.0 +structlog==24.1.0 +sympy==1.12 +tensorboard==2.7.0 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +texttable==1.7.0 +torch==1.13.1 +torchaudio==0.13.1 +torchvision==0.14.1 +tqdm==4.66.1 +tzdata==2023.4 +uvicorn==0.27.0.post1 +watchfiles==0.21.0 +websockets==12.0 +Werkzeug==3.0.1 + +Division of work: 晏永磊 did all the work, as he was the only member of this group. diff --git a/checkpoints/README.md b/checkpoints/README.md index 0a1432d6..a77e5892 100644 --- a/checkpoints/README.md +++ b/checkpoints/README.md @@ -1,7 +1,10 @@ This folder stores some additional pretrained weights used during model training or evaluation, such as: +- The data used by LiveSpeechPortraits goes in the live_speech_portraits directory. +- Please download the data required by LiveSpeechPortraits from [data - Google 云端硬盘](https://drive.google.com/drive/folders/1sHc2xEEGwnb0h2rkUhG9sPmOxvRvPVpJ) and place it in live_speech_portraits. - the syncnet weights used by wav2lip - the syncnet-v2 weights used to compute the LSE lip-audio sync metric for generated videos - .......
+- 目录结构为: ``` diff --git a/environment.yml b/environment.yml deleted file mode 100644 index 09a8595b..00000000 --- a/environment.yml +++ /dev/null @@ -1,138 +0,0 @@ -name: torch38 -channels: - - defaults -dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=5.1=1_gnu - - ca-certificates=2023.08.22=h06a4308_0 - - ld_impl_linux-64=2.38=h1181459_1 - - libffi=3.4.4=h6a678d5_0 - - libgcc-ng=11.2.0=h1234567_1 - - libgomp=11.2.0=h1234567_1 - - libstdcxx-ng=11.2.0=h1234567_1 - - ncurses=6.4=h6a678d5_0 - - openssl=3.0.12=h7f8727e_0 - - pip=23.3=py38h06a4308_0 - - python=3.8.18=h955ad1f_0 - - readline=8.2=h5eee18b_0 - - setuptools=68.0.0=py38h06a4308_0 - - sqlite=3.41.2=h5eee18b_0 - - tk=8.6.12=h1ccaba5_0 - - wheel=0.41.2=py38h06a4308_0 - - xz=5.4.2=h5eee18b_0 - - zlib=1.2.13=h5eee18b_0 - - pip: - - absl-py==2.0.0 - - addict==2.4.0 - - aiosignal==1.3.1 - - appdirs==1.4.4 - - attrs==23.1.0 - - audioread==3.0.1 - - basicsr==1.3.4.7 - - cachetools==5.3.2 - - certifi==2020.12.5 - - cffi==1.16.0 - - charset-normalizer==3.3.2 - - click==8.1.7 - - cloudpickle==3.0.0 - - colorama==0.4.6 - - colorlog==6.7.0 - - contourpy==1.1.1 - - cycler==0.12.1 - - decorator==5.1.1 - - dlib==19.22.1 - - docker-pycreds==0.4.0 - - face-alignment==1.3.5 - - ffmpeg==1.4 - - filelock==3.13.1 - - fonttools==4.44.0 - - frozenlist==1.4.0 - - future==0.18.3 - - gitdb==4.0.11 - - gitpython==3.1.40 - - glob2==0.7 - - google-auth==2.23.4 - - google-auth-oauthlib==0.4.6 - - grpcio==1.59.2 - - hyperopt==0.2.5 - - idna==3.4 - - imageio==2.9.0 - - imageio-ffmpeg==0.4.5 - - importlib-metadata==6.8.0 - - importlib-resources==6.1.0 - - joblib==1.3.2 - - jsonschema==4.19.2 - - jsonschema-specifications==2023.7.1 - - kiwisolver==1.4.5 - - kornia==0.5.5 - - lazy-loader==0.3 - - librosa==0.10.1 - - llvmlite==0.37.0 - - lmdb==1.2.1 - - lws==1.2.7 - - markdown==3.5.1 - - markupsafe==2.1.3 - - matplotlib==3.6.3 - - msgpack==1.0.7 - - networkx==3.1 - - numba==0.54.1 - - numpy==1.20.3 - - oauthlib==3.2.2 - - 
opencv-python==3.4.9.33 - - packaging==23.2 - - pandas==1.3.4 - - pathtools==0.1.2 - - pillow==6.2.1 - - pkgutil-resolve-name==1.3.10 - - platformdirs==3.11.0 - - plotly==5.18.0 - - pooch==1.8.0 - - protobuf==4.25.0 - - psutil==5.9.6 - - pyasn1==0.5.0 - - pyasn1-modules==0.3.0 - - pycparser==2.21 - - pyparsing==3.1.1 - - python-dateutil==2.8.2 - - python-speech-features==0.6 - - pytorch-fid==0.3.0 - - pytz==2023.3.post1 - - pywavelets==1.4.1 - - pyyaml==5.3.1 - - ray==2.6.3 - - referencing==0.30.2 - - requests==2.31.0 - - requests-oauthlib==1.3.1 - - rpds-py==0.12.0 - - rsa==4.9 - - scikit-image==0.16.2 - - scikit-learn==1.3.2 - - scipy==1.5.0 - - sentry-sdk==1.34.0 - - setproctitle==1.3.3 - - six==1.16.0 - - smmap==5.0.1 - - soundfile==0.12.1 - - soxr==0.3.7 - - tabulate==0.9.0 - - tb-nightly==2.12.0a20230126 - - tenacity==8.2.3 - - tensorboard==2.7.0 - - tensorboard-data-server==0.6.1 - - tensorboard-plugin-wit==1.8.1 - - texttable==1.7.0 - - thop==0.1.1-2209072238 - - threadpoolctl==3.2.0 - - tomli==2.0.1 - - torch==1.13.1+cu116 - - torchaudio==0.13.1+cu116 - - torchvision==0.14.1+cu116 - - tqdm==4.66.1 - - trimesh==3.9.20 - - typing-extensions==4.8.0 - - tzdata==2023.3 - - urllib3==2.0.7 - - wandb==0.15.12 - - werkzeug==3.0.1 - - yapf==0.40.2 - - zipp==3.17.0 diff --git a/requirements.txt b/requirements.txt index 1605c1fe..30c7839d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,114 +1,76 @@ -absl-py==2.0.0 -addict==2.4.0 +absl-py==2.1.0 aiosignal==1.3.1 -appdirs==1.4.4 -attrs==23.1.0 -audioread==3.0.1 -basicsr==1.3.4.7 +albumentations==0.5.2 +anyio==4.2.0 +attrs==23.2.0 cachetools==5.3.2 -certifi==2020.12.5 -cffi==1.16.0 -charset-normalizer==3.3.2 click==8.1.7 -cloudpickle==3.0.0 -colorama==0.4.6 +cog==0.9.4 colorlog==6.7.0 -contourpy==1.1.1 -cycler==0.12.1 -decorator==5.1.1 -dlib==19.22.1 -docker-pycreds==0.4.0 -face-alignment==1.3.5 -ffmpeg==1.4 +dominate==2.9.1 +exceptiongroup==1.2.0 +fastapi==0.98.0 filelock==3.13.1 -fonttools==4.44.0 
-frozenlist==1.4.0 -future==0.18.3 -gitdb==4.0.11 -GitPython==3.1.40 -glob2==0.7 -google-auth==2.23.4 +fonttools==4.25.0 +frozenlist==1.4.1 +google-auth==2.26.2 google-auth-oauthlib==0.4.6 -grpcio==1.59.2 -hyperopt==0.2.5 -idna==3.4 -imageio==2.9.0 -imageio-ffmpeg==0.4.5 -importlib-metadata==6.8.0 -importlib-resources==6.1.0 -joblib==1.3.2 -jsonschema==4.19.2 -jsonschema-specifications==2023.7.1 -kiwisolver==1.4.5 -kornia==0.5.5 -lazy_loader==0.3 -librosa==0.10.1 -llvmlite==0.37.0 -lmdb==1.2.1 -lws==1.2.7 -Markdown==3.5.1 -MarkupSafe==2.1.3 -matplotlib==3.6.3 -msgpack==1.0.7 +grpcio==1.60.0 +h11==0.14.0 +h5py==3.10.0 +httptools==0.6.1 +imageio==2.33.1 +imgaug==0.4.0 +Jinja2==3.1.3 +jsonschema==4.21.0 +jsonschema-specifications==2023.12.1 +librosa==0.7.0 +llvmlite==0.31.0 +Markdown==3.5.2 +MarkupSafe==2.1.4 +mkl-service==2.4.0 +mpmath==1.3.0 +munkres==1.1.4 networkx==3.1 -numba==0.54.1 +numba==0.48.0 numpy==1.20.3 oauthlib==3.2.2 opencv-python==3.4.9.33 -packaging==23.2 +opencv-python-headless==4.9.0.80 pandas==1.3.4 -pathtools==0.1.2 -Pillow==6.2.1 pkgutil_resolve_name==1.3.10 -platformdirs==3.11.0 -plotly==5.18.0 -pooch==1.8.0 -protobuf==4.25.0 -psutil==5.9.6 -pyasn1==0.5.0 +protobuf==3.19.0 +pyasn1==0.5.1 pyasn1-modules==0.3.0 -pycparser==2.21 -pyparsing==3.1.1 -python-dateutil==2.8.2 +pydantic==1.10.14 +python-dotenv==1.0.1 python-speech-features==0.6 -pytorch-fid==0.3.0 pytz==2023.3.post1 PyWavelets==1.4.1 -PyYAML==5.3.1 +PyYAML==6.0.1 ray==2.6.3 -referencing==0.30.2 -requests==2.31.0 +referencing==0.32.1 requests-oauthlib==1.3.1 -rpds-py==0.12.0 +resampy==0.3.1 +rpds-py==0.17.1 rsa==4.9 scikit-image==0.16.2 -scikit-learn==1.3.2 -scipy==1.5.0 -sentry-sdk==1.34.0 -setproctitle==1.3.3 -six==1.16.0 -smmap==5.0.1 -soundfile==0.12.1 -soxr==0.3.7 -tabulate==0.9.0 -tb-nightly==2.12.0a20230126 -tenacity==8.2.3 +scipy==1.10.1 +shapely==2.0.2 +sniffio==1.3.0 +starlette==0.27.0 +structlog==24.1.0 +sympy==1.12 tensorboard==2.7.0 tensorboard-data-server==0.6.1 
tensorboard-plugin-wit==1.8.1 texttable==1.7.0 -thop==0.1.1.post2209072238 -threadpoolctl==3.2.0 -tomli==2.0.1 -torch==1.13.1+cu116 -torchaudio==0.13.1+cu116 -torchvision==0.14.1+cu116 +torch==1.13.1 +torchaudio==0.13.1 +torchvision==0.14.1 tqdm==4.66.1 -trimesh==3.9.20 -typing_extensions==4.8.0 -tzdata==2023.3 -urllib3==2.0.7 -wandb==0.15.12 +tzdata==2023.4 +uvicorn==0.27.0.post1 +watchfiles==0.21.0 +websockets==12.0 Werkzeug==3.0.1 -yapf==0.40.2 -zipp==3.17.0 diff --git a/run_talkingface.py b/run_talkingface.py index 3989d566..9f355a49 100644 --- a/run_talkingface.py +++ b/run_talkingface.py @@ -1,4 +1,5 @@ import argparse +import torch from talkingface.quick_start import run if __name__ == "__main__": diff --git a/talkingface/config/configurator.py b/talkingface/config/configurator.py index 7b6e21d8..0cdce648 100644 --- a/talkingface/config/configurator.py +++ b/talkingface/config/configurator.py @@ -73,7 +73,7 @@ def __init__( self._load_internal_config_dict(self.model, self.model_class, self.dataset) self.final_config_dict = self._get_final_config_dict() self._set_default_parameters() - self._init_device() + # self._init_device() def _init_parameters_category(self): self.parameters = dict() diff --git a/talkingface/data/dataprocess/LiveSpeechPortraits_process.py b/talkingface/data/dataprocess/LiveSpeechPortraits_process.py new file mode 100644 index 00000000..e69de29b diff --git a/talkingface/data/dataset/LiveSpeechPortraits/__init__.py b/talkingface/data/dataset/LiveSpeechPortraits/__init__.py new file mode 100644 index 00000000..93a0b38c --- /dev/null +++ b/talkingface/data/dataset/LiveSpeechPortraits/__init__.py @@ -0,0 +1,93 @@ +"""This package includes all the modules related to data loading and preprocessing + + To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 
+ You need to implement four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point from data loader. + -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. + +Now you can use the dataset class by specifying flag '--dataset_mode dummy'. +See our template dataset class 'template_dataset.py' for more details. +""" +import importlib +import torch.utils.data +from talkingface.data.dataset.LiveSpeechPortraits.base_dataset import BaseDataset + + +def find_dataset_using_name(dataset_name): + """Import the module "data/[dataset_name]_dataset.py". + + In the file, the class called DatasetNameDataset() will + be instantiated. It has to be a subclass of BaseDataset, + and it is case-insensitive. + """ + dataset_filename = "talkingface.data.dataset.LiveSpeechPortraits." + dataset_name + "_dataset" + datasetlib = importlib.import_module(dataset_filename) + + dataset = None + target_dataset_name = dataset_name.replace('_', '') + 'dataset' + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() \ + and issubclass(cls, BaseDataset): + dataset = cls + + if dataset is None: + raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) + + return dataset + + +def get_option_setter(dataset_name): + """Return the static method <modify_commandline_options> of the dataset class.""" + dataset_class = find_dataset_using_name(dataset_name) + return dataset_class.modify_commandline_options + + +def create_dataset(opt): + """Create a dataset given the option. + + This function wraps the class CustomDatasetDataLoader.
+ This is the main interface between this package and 'train.py'/'test.py' + + Example: +# >>> from data import create_dataset +# >>> dataset = create_dataset(opt) + """ + data_loader = CustomDatasetDataLoader(opt) + dataset = data_loader.load_data() + return dataset + + +class CustomDatasetDataLoader(): + """Wrapper class of Dataset class that performs multi-threaded data loading""" + + def __init__(self, opt): + """Initialize this class + + Step 1: create a dataset instance given the name [dataset_mode] + Step 2: create a multi-threaded data loader. + """ + self.opt = opt + dataset_class = find_dataset_using_name(opt.dataset_mode) + self.dataset = dataset_class(opt) + print("dataset [%s] was created" % type(self.dataset).__name__) + self.dataloader = torch.utils.data.DataLoader( + self.dataset, + batch_size=opt.batch_size, + shuffle=not opt.serial_batches, + num_workers=int(opt.num_threads)) + + def load_data(self): + return self + + def __len__(self): + """Return the number of data in the dataset""" + return min(len(self.dataset), self.opt.max_dataset_size) + + def __iter__(self): + """Return a batch of data""" + for i, data in enumerate(self.dataloader): + if i * self.opt.batch_size >= self.opt.max_dataset_size: + break + yield data diff --git a/talkingface/data/dataset/LiveSpeechPortraits/audiovisual_dataset.py b/talkingface/data/dataset/LiveSpeechPortraits/audiovisual_dataset.py new file mode 100644 index 00000000..784700c6 --- /dev/null +++ b/talkingface/data/dataset/LiveSpeechPortraits/audiovisual_dataset.py @@ -0,0 +1,301 @@ +import sys +sys.path.append("..") + +import scipy.io as sio +import torch +import librosa +import bisect +import os +import numpy as np +from talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.networks import APC_encoder +from talkingface.data.dataset.LiveSpeechPortraits.base_dataset import BaseDataset +from talkingface.utils.live_speech_portraits import utils + + + +class AudioVisualDataset(BaseDataset): + """ 
audio-visual dataset. currently, return 2D info and 3D tracking info. + + # for wavenet: + # |----receptive_field----| + # |--output_length--| + # example: | | | | | | | | | | | | | | | | | | | | | + # target: | | | | | | | | | | + + """ + def __init__(self, opt): + # save the option and dataset root + BaseDataset.__init__(self, opt) + + self.isTrain = self.opt.isTrain + self.state = opt.dataset_type + self.dataset_name = opt.dataset_names + self.target_length = opt.time_frame_length + self.sample_rate = opt.sample_rate + self.fps = opt.FPS + + self.audioRF_history = opt.audioRF_history + self.audioRF_future = opt.audioRF_future + self.compute_mel_online = opt.compute_mel_online + self.feature_name = opt.feature_name + + self.audio_samples_one_frame = self.sample_rate / self.fps + self.frame_jump_stride = opt.frame_jump_stride + self.augment = False + self.task = opt.task + self.item_length_audio = int((self.audioRF_history + self.audioRF_future)/ self.fps * self.sample_rate) + + if self.task == 'Audio2Feature': + if opt.feature_decoder == 'WaveNet': + self.A2L_receptive_field = opt.A2L_receptive_field + self.A2L_item_length = self.A2L_receptive_field + self.target_length - 1 + elif opt.feature_decoder == 'LSTM': + self.A2L_receptive_field = 30 + self.A2L_item_length = self.A2L_receptive_field + self.target_length - 1 + elif self.task == 'Audio2Headpose': + self.A2H_receptive_field = opt.A2H_receptive_field + self.A2H_item_length = self.A2H_receptive_field + self.target_length - 1 + self.audio_window = opt.audio_windows + self.half_audio_win = int(self.audio_window / 2) + + self.frame_future = opt.frame_future + self.predict_length = opt.predict_length + self.predict_len = int((self.predict_length - 1) / 2) + + self.gpu_ids = opt.gpu_ids + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + print('self.device:', self.device) + if self.task == 'Audio2Feature': + self.seq_len = opt.sequence_length + + 
self.total_len = 0 + self.dataset_root = os.path.join(self.root, self.dataset_name) + if self.state == 'Train': + self.clip_names = opt.train_dataset_names + elif self.state == 'Val': + self.clip_names = opt.validate_dataset_names + elif self.state == 'Test': + self.clip_names = opt.test_dataset_names + + + self.clip_nums = len(self.clip_names) + # main info + self.audio = [''] * self.clip_nums + self.audio_features = [''] * self.clip_nums + self.feats = [''] * self.clip_nums + self.exps = [''] * self.clip_nums + self.pts3d = [''] * self.clip_nums + self.rot_angles = [''] * self.clip_nums + self.trans = [''] * self.clip_nums + self.headposes = [''] * self.clip_nums + self.velocity_pose = [''] * self.clip_nums + self.acceleration_pose = [''] * self.clip_nums + self.mean_trans = [''] * self.clip_nums + if self.state == 'Test': + self.landmarks = [''] * self.clip_nums + # meta info + self.start_point = [''] * self.clip_nums + self.end_point = [''] * self.clip_nums + self.len = [''] * self.clip_nums + self.sample_start = [] + self.clip_valid = ['True'] * self.clip_nums + self.invalid_clip = [] + + + self.mouth_related_indices = np.concatenate([np.arange(4, 11), np.arange(46, 64)]) + if self.task == 'Audio2Feature': + if self.opt.only_mouth: + self.indices = self.mouth_related_indices + else: + self.indices = np.arange(73) + if opt.use_delta_pts: + self.pts3d_mean = np.load(os.path.join(self.dataset_root, 'mean_pts3d.npy')) + + for i in range(self.clip_nums): + name = self.clip_names[i] + clip_root = os.path.join(self.dataset_root, name) + # audio + if os.path.exists(os.path.join(clip_root, name + '_denoise.wav')): + audio_path = os.path.join(clip_root, name + '_denoise.wav') + print('find denoised wav!') + else: + audio_path = os.path.join(clip_root, name + '.wav') + self.audio[i], _ = librosa.load(audio_path, sr=self.sample_rate) + + if self.opt.audio_encoder == 'APC': + APC_name = os.path.split(self.opt.APC_model_path)[-1] + APC_feature_file = name + 
'_APC_feature_V0324_ckpt_{}.npy'.format(APC_name) + APC_feature_path = os.path.join(clip_root, APC_feature_file) + need_deepfeats = False if os.path.exists(APC_feature_path) else True + if not need_deepfeats: + self.audio_features[i] = np.load(APC_feature_path).astype(np.float32) + else: + need_deepfeats = False + + + # 3D landmarks & headposes + if self.task == 'Audio2Feature': + self.start_point[i] = 0 + elif self.task == 'Audio2Headpose': + self.start_point[i] = 300 + fit_data_path = os.path.join(clip_root, '3d_fit_data.npz') + fit_data = np.load(fit_data_path) + if not opt.ispts_norm: + ori_pts3d = fit_data['pts_3d'].astype(np.float32) + else: + ori_pts3d = np.load(os.path.join(clip_root, 'tracked3D_normalized_pts_fix_contour.npy')) + if opt.use_delta_pts: + self.pts3d[i] = ori_pts3d - self.pts3d_mean + else: + self.pts3d[i] = ori_pts3d + if opt.feature_dtype == 'pts3d': + self.feats[i] = self.pts3d[i] + elif opt.feature_dtype == 'FW': + track_data_path = os.path.join(clip_root, 'tracking_results.mat') + self.feats[i] = sio.loadmat(track_data_path)['exps'].astype(np.float32) + self.rot_angles[i] = fit_data['rot_angles'].astype(np.float32) + # change -180~180 to 0~360 + if not self.dataset_name == 'Yuxuan': + rot_change = self.rot_angles[i][:, 0] < 0 + self.rot_angles[i][rot_change, 0] += 360 + self.rot_angles[i][:,0] -= 180 # change x axis direction + # use delta translation + self.mean_trans[i] = fit_data['trans'][:,:,0].astype(np.float32).mean(axis=0) + self.trans[i] = fit_data['trans'][:,:,0].astype(np.float32) - self.mean_trans[i] + + self.headposes[i] = np.concatenate([self.rot_angles[i], self.trans[i]], axis=1) + self.velocity_pose[i] = np.concatenate([np.zeros(6)[None,:], self.headposes[i][1:] - self.headposes[i][:-1]]) + self.acceleration_pose[i] = np.concatenate([np.zeros(6)[None,:], self.velocity_pose[i][1:] - self.velocity_pose[i][:-1]]) + + if self.dataset_name == 'Yuxuan': + total_frames = self.feats[i].shape[0] - 300 - 130 + else: + total_frames = 
self.feats[i].shape[0] - 60 + + + if need_deepfeats: + if self.opt.audio_encoder == 'APC': + print('dataset {} need to pre-compute APC features ...'.format(name)) + print('first we compute mel spectram for dataset {} '.format(name)) + mel80 = utils.compute_mel_one_sequence(self.audio[i]) + mel_nframe = mel80.shape[0] + print('loading pre-trained model: ', self.opt.APC_model_path) + APC_model = APC_encoder(self.opt.audiofeature_input_channels, + self.opt.APC_hidden_size, + self.opt.APC_rnn_layers, + self.opt.APC_residual) + APC_model.load_state_dict(torch.load(self.opt.APC_model_path, map_location=str(self.device)), strict=False) +# APC_model.load_state_dict(torch.load(self.opt.APC_model_path), strict=False) + APC_model.cuda() + APC_model.eval() + with torch.no_grad(): + length = torch.Tensor([mel_nframe]) +# hidden_reps = torch.zeros([mel_nframe, self.opt.APC_hidden_size]).cuda() + mel80_torch = torch.from_numpy(mel80.astype(np.float32)).cuda().unsqueeze(0) + hidden_reps = APC_model.forward(mel80_torch, length)[0] # [mel_nframe, 512] + hidden_reps = hidden_reps.cpu().numpy() + np.save(APC_feature_path, hidden_reps) + self.audio_features[i] = hidden_reps + + + valid_frames = total_frames - self.start_point[i] + self.len[i] = valid_frames - 400 + if i == 0: + self.sample_start.append(0) + else: + self.sample_start.append(self.sample_start[-1] + self.len[i-1] - 1) + self.total_len += np.int32(np.floor(self.len[i] / self.frame_jump_stride)) + + + + def __getitem__(self, index): + # recover real index from compressed one + index_real = np.int32(index * self.frame_jump_stride) + # find which audio file and the start frame index + file_index = bisect.bisect_right(self.sample_start, index_real) - 1 + current_frame = index_real - self.sample_start[file_index] + self.start_point[file_index] + current_target_length = self.target_length + + if self.task == 'Audio2Feature': + # start point is current frame + A2Lsamples = self.audio_features[file_index][current_frame * 2 : 
(current_frame + self.seq_len) * 2] + target_pts3d = self.feats[file_index][current_frame : current_frame + self.seq_len, self.indices].reshape(self.seq_len, -1) + + A2Lsamples = torch.from_numpy(A2Lsamples).float() + target_pts3d = torch.from_numpy(target_pts3d).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Lsamples, target_pts3d + + + elif self.task == 'Audio2Headpose': + if self.opt.feature_decoder == 'WaveNet': + # find the history info start points + A2H_history_start = current_frame - self.A2H_receptive_field + A2H_item_length = self.A2H_item_length + A2H_receptive_field = self.A2H_receptive_field + + if self.half_audio_win == 1: + A2Hsamples = self.audio_features[file_index][2 * (A2H_history_start + self.frame_future) : 2 * (A2H_history_start + self.frame_future + A2H_item_length)] + else: + A2Hsamples = np.zeros([A2H_item_length, self.audio_window, 512]) + for i in range(A2H_item_length): + A2Hsamples[i] = self.audio_features[file_index][2 * (A2H_history_start + i) - self.half_audio_win : 2 * (A2H_history_start + i) + self.half_audio_win] + + if self.predict_len == 0: + target_headpose = self.headposes[file_index][A2H_history_start + A2H_receptive_field : A2H_history_start + A2H_item_length + 1] + history_headpose = self.headposes[file_index][A2H_history_start : A2H_history_start + A2H_item_length].reshape(A2H_item_length, -1) + + target_velocity = self.velocity_pose[file_index][A2H_history_start + A2H_receptive_field : A2H_history_start + A2H_item_length + 1] + history_velocity = self.velocity_pose[file_index][A2H_history_start : A2H_history_start + A2H_item_length].reshape(A2H_item_length, -1) + target_info = torch.from_numpy(np.concatenate([target_headpose, target_velocity], axis=1).reshape(current_target_length, -1)).float() + else: + history_headpose = self.headposes[file_index][A2H_history_start : A2H_history_start + A2H_item_length].reshape(A2H_item_length, -1) + history_velocity = 
self.velocity_pose[file_index][A2H_history_start : A2H_history_start + A2H_item_length].reshape(A2H_item_length, -1) + + + target_headpose_ = self.headposes[file_index][A2H_history_start + A2H_receptive_field - self.predict_len : A2H_history_start + A2H_item_length + 1 + self.predict_len + 1] + target_headpose = np.zeros([current_target_length, self.predict_length, target_headpose_.shape[1]]) + for i in range(current_target_length): + target_headpose[i] = target_headpose_[i : i + self.predict_length] + target_headpose = target_headpose#.reshape(current_target_length, -1, order='F') + + target_velocity_ = self.headposes[file_index][A2H_history_start + A2H_receptive_field - self.predict_len : A2H_history_start + A2H_item_length + 1 + self.predict_len + 1] + target_velocity = np.zeros([current_target_length, self.predict_length, target_velocity_.shape[1]]) + for i in range(current_target_length): + target_velocity[i] = target_velocity_[i : i + self.predict_length] + target_velocity = target_velocity#.reshape(current_target_length, -1, order='F') + + target_info = torch.from_numpy(np.concatenate([target_headpose, target_velocity], axis=2).reshape(current_target_length, -1)).float() + + A2Hsamples = torch.from_numpy(A2Hsamples).float() + + + history_info = torch.from_numpy(np.concatenate([history_headpose, history_velocity], axis=1)).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Hsamples, history_info, target_info + + + elif self.opt.feature_decoder == 'LSTM': + A2Hsamples = self.audio_features[file_index][current_frame * 2 : (current_frame + self.opt.A2H_receptive_field) * 2] + + target_headpose = self.headposes[file_index][current_frame : current_frame + self.opt.A2H_receptive_field] + target_velocity = self.velocity_pose[file_index][current_frame : current_frame + self.opt.A2H_receptive_field] + target_info = torch.from_numpy(np.concatenate([target_headpose, target_velocity], 
axis=1).reshape(self.opt.A2H_receptive_field, -1)).float() + + A2Hsamples = torch.from_numpy(A2Hsamples).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Hsamples, target_info + + + + def __len__(self): + return self.total_len + + + + + + diff --git a/talkingface/data/dataset/LiveSpeechPortraits/base_dataset.py b/talkingface/data/dataset/LiveSpeechPortraits/base_dataset.py new file mode 100644 index 00000000..3936feff --- /dev/null +++ b/talkingface/data/dataset/LiveSpeechPortraits/base_dataset.py @@ -0,0 +1,65 @@ +"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" + +import numpy as np +import torch.utils.data as data +from PIL import Image +import torchvision.transforms as transforms +from abc import ABC, abstractmethod + + +class BaseDataset(data.Dataset, ABC): + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. + -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + self.root = opt.dataroot + + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+ + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index -- a random integer for data indexing + + Returns: + a dictionary of data with their names. It usually contains the data itself and its metadata information. + """ + pass + + + + diff --git a/talkingface/data/dataset/LiveSpeechPortraits/face_dataset.py b/talkingface/data/dataset/LiveSpeechPortraits/face_dataset.py new file mode 100644 index 00000000..42e22337 --- /dev/null +++ b/talkingface/data/dataset/LiveSpeechPortraits/face_dataset.py @@ -0,0 +1,376 @@ +import os +from talkingface.data.dataset.LiveSpeechPortraits.base_dataset import BaseDataset +import os.path +from pathlib import Path +import torch +from skimage.io import imread, imsave +from PIL import Image +import bisect +import numpy as np +import io +import cv2 +import h5py +import albumentations as A + + +class FaceDataset(BaseDataset): + def __init__(self, opt): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + + A few things can be done here. + - save the options (have been done in BaseDataset) + - get image paths and meta information of the dataset. + - define the image transformation.
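`FaceDataset` normalizes every clip the same way: rescale each frame by `scale`, cut a 512x512 window centred on `(xc, yc)`, and zero-pad whenever the window is clipped at an image border (the `change_paras.npz` and `image_pad` handling below). A dependency-free sketch of that logic, assuming only NumPy; the function name `normalize_frame` is illustrative, not from the repository:

```python
import numpy as np

def normalize_frame(img, scale, xc, yc, out=512):
    """Rescale, then cut an (out x out) window centred on (xc, yc),
    zero-padding whatever falls outside the rescaled image."""
    h, w = img.shape[:2]
    # nearest-neighbour rescale (the repo uses A.Resize; this keeps the sketch dependency-free)
    ys = np.clip((np.arange(int(h * scale)) / scale).astype(int), 0, h - 1)
    xs = np.clip((np.arange(int(w * scale)) / scale).astype(int), 0, w - 1)
    img = img[ys][:, xs]
    h, w = img.shape[:2]
    half = out // 2
    x_min, x_max = max(xc - half, 0), min(xc + half, w)
    y_min, y_max = max(yc - half, 0), min(yc + half, h)
    crop = img[y_min:y_max, x_min:x_max]
    # pad back to out x out when the window was clipped at a border
    top, bottom = y_min - (yc - half), (yc + half) - y_max
    left, right = x_min - (xc - half), (xc + half) - x_max
    pad = [(top, bottom), (left, right)] + [(0, 0)] * (img.ndim - 2)
    return np.pad(crop, pad)  # zero padding, like cv2.copyMakeBorder BORDER_CONSTANT
```

The same top/bottom/left/right amounts are what the dataset stores in `image_pad`, so shoulder points can later be shifted consistently with the padded image.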
+ """ + BaseDataset.__init__(self, opt) + + self.state = 'Train' if self.opt.isTrain else 'Test' + self.dataset_name = opt.dataset_names[0] + + # default settings + # currently, we have 8 parts for face parts + self.part_list = [[list(range(0, 15))], # contour + [[15,16,17,18,18,19,20,15]], # right eyebrow + [[21,22,23,24,24,25,26,21]], # left eyebrow + [range(35, 44)], # nose + [[27,65,28,68,29], [29,67,30,66,27]], # right eye + [[33,69,32,72,31], [31,71,34,70,33]], # left eye + [range(46, 53), [52,53,54,55,56,57,46]], # mouth + [[46,63,62,61,52], [52,60,59,58,46]] # tongue + ] + self.mouth_outer = [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 46] + self.label_list = [1, 1, 2, 3, 3, 4, 5] # labeling for different facial parts + + # only load in train mode + + self.dataset_root = os.path.join(self.root, self.dataset_name) + if self.state == 'Train': + self.clip_names = opt.train_dataset_names + elif self.state == 'Val': + self.clip_names = opt.validate_dataset_names + elif self.state == 'Test': + self.clip_names = opt.test_dataset_names + + self.clip_nums = len(self.clip_names) + + # load pts & image info + self.landmarks2D, self.len, self.sample_len = [''] * self.clip_nums, [''] * self.clip_nums, [''] * self.clip_nums + self.image_transforms, self.image_pad, self.tgts_paths = [''] * self.clip_nums, [''] * self.clip_nums, [''] * self.clip_nums + self.shoulders, self.shoulder3D = [''] * self.clip_nums, [''] * self.clip_nums + self.sample_start = [] + + # tracked 3d info & candidates images + self.pts3d, self.rot, self.trans = [''] * self.clip_nums, [''] * self.clip_nums, [''] * self.clip_nums + self.full_cand = [''] * self.clip_nums + self.headposes = [''] * self.clip_nums + + self.total_len = 0 + if self.opt.isTrain: + for i in range(self.clip_nums): + name = self.clip_names[i] + clip_root = os.path.join(self.dataset_root, name) + # basic image info + img_file_path = os.path.join(clip_root, name + '.h5') + img_file = h5py.File(img_file_path, 'r')[name] + example 
= np.asarray(Image.open(io.BytesIO(img_file[0]))) + h, w, _ = example.shape + + + landmark_path = os.path.join(clip_root, 'tracked2D_normalized_pts_fix_contour.npy') + self.landmarks2D[i] = np.load(landmark_path).astype(np.float32) + change_paras = np.load(os.path.join(clip_root, 'change_paras.npz')) + scale, xc, yc = change_paras['scale'], change_paras['xc'], change_paras['yc'] + x_min, x_max, y_min, y_max = xc-256, xc+256, yc-256, yc+256 + # if need padding + x_min, x_max, y_min, y_max, self.image_pad[i] = max(x_min, 0), min(x_max, w), max(y_min, 0), min(y_max, h), None + + + if x_min == 0 or x_max == 512 or y_min == 0 or y_max == 512: + top, bottom, left, right = abs(yc-256-y_min), abs(yc+256-y_max), abs(xc-256-x_min), abs(xc+256-x_max) + self.image_pad[i] = [top, bottom, left, right] + self.image_transforms[i] = A.Compose([ + A.Resize(np.int32(h*scale), np.int32(w*scale)), + A.Crop(x_min, y_min, x_max, y_max)]) + + if self.opt.isH5: + tgt_file_path = os.path.join(clip_root, name + '.h5') + tgt_file = h5py.File(tgt_file_path, 'r')[name] + image_length = len(tgt_file) + else: + tgt_paths = list(map(lambda x:str(x), sorted(list(Path(clip_root).glob('*'+self.opt.suffix)), key=lambda x: int(x.stem)))) + image_length = len(tgt_paths) + self.tgts_paths[i] = tgt_paths + if not self.landmarks2D[i].shape[0] == image_length: + raise ValueError('In dataset {} length of landmarks and images are not equal!'.format(name)) + + # tracked 3d info + fit_data_path = os.path.join(clip_root, '3d_fit_data.npz') + fit_data = np.load(fit_data_path) + self.pts3d[i] = fit_data['pts_3d'].astype(np.float32) + self.rot[i] = fit_data['rot_angles'].astype(np.float32) + self.trans[i] = fit_data['trans'][:,:,0].astype(np.float32) + if not self.pts3d[i].shape[0] == image_length: + raise ValueError('In dataset {} length of 3d pts and images are not equal!'.format(name)) + + # candidates images + + tmp = [] + for j in range(4): + try: + output = imread(os.path.join(clip_root, 'candidates', 
f'normalized_full_{j}.jpg')) + except: + imgc = imread(os.path.join(clip_root, 'candidates', f'full_{j}.jpg')) + output = self.common_dataset_transform(imgc, i) + imsave(os.path.join(clip_root, 'candidates', f'normalized_full_{j}.jpg'), output) + output = A.pytorch.transforms.ToTensor(normalize={'mean':(0.5,0.5,0.5), 'std':(0.5,0.5,0.5)})(image=output)['image'] + tmp.append(output) + self.full_cand[i] = torch.cat(tmp) + + # headpose + fit_data_path = os.path.join(clip_root, '3d_fit_data.npz') + fit_data = np.load(fit_data_path) + rot_angles = fit_data['rot_angles'].astype(np.float32) + # change -180~180 to 0~360 + if not self.dataset_name == 'Yuxuan': + rot_change = rot_angles[:, 0] < 0 + rot_angles[rot_change, 0] += 360 + rot_angles[:,0] -= 180 # change x axis direction + # use delta translation + mean_trans = fit_data['trans'][:,:,0].astype(np.float32).mean(axis=0) + trans = fit_data['trans'][:,:,0].astype(np.float32) - mean_trans + + self.headposes[i] = np.concatenate([rot_angles, trans], axis=1) + + # shoulders + shoulder_path = os.path.join(clip_root, 'normalized_shoulder_points.npy') + self.shoulders[i] = np.load(shoulder_path) + shoulder3D_path = os.path.join(clip_root, 'shoulder_points3D.npy') + self.shoulder3D[i] = np.load(shoulder3D_path) + + + self.sample_len[i] = np.int32(np.floor((self.landmarks2D[i].shape[0] - 60) / self.opt.frame_jump) + 1) + self.len[i] = self.landmarks2D[i].shape[0] + if i == 0: + self.sample_start.append(0) + else: + self.sample_start.append(self.sample_start[-1] + self.sample_len[i-1]) # not minus 1 + self.total_len += self.sample_len[i] + + # test mode + else: + # if need padding + example = imread(os.path.join(self.root, 'example.png')) + h, w, _ = example.shape + change_paras = np.load(os.path.join(self.root, 'change_paras.npz')) + scale, xc, yc = change_paras['scale'], change_paras['xc'], change_paras['yc'] + x_min, x_max, y_min, y_max = xc-256, xc+256, yc-256, yc+256 + x_min, x_max, y_min, y_max, self.image_pad = max(x_min, 
0), min(x_max, w), max(y_min, 0), min(y_max, h), None + + + if x_min == 0 or x_max == 512 or y_min == 0 or y_max == 512: + top, bottom, left, right = abs(yc-256-y_min), abs(yc+256-y_max), abs(xc-256-x_min), abs(xc+256-x_max) + self.image_pad = [top, bottom, left, right] + + + + + + def __getitem__(self, ind): + dataset_index = bisect.bisect_right(self.sample_start, ind) - 1 + data_index = (ind - self.sample_start[dataset_index]) * self.opt.frame_jump + np.random.randint(self.opt.frame_jump) + + target_ind = data_index + 1 # history_ind, current_ind + landmarks = self.landmarks2D[dataset_index][target_ind] # [73, 2] + shoulders = self.shoulders[dataset_index][target_ind].copy() + + dataset_name = self.clip_names[dataset_index] + clip_root = os.path.join(self.dataset_root, dataset_name) + if self.opt.isH5: + tgt_file_path = os.path.join(clip_root, dataset_name + '.h5') + tgt_file = h5py.File(tgt_file_path, 'r')[dataset_name] + tgt_image = np.asarray(Image.open(io.BytesIO(tgt_file[target_ind]))) + + # do transform + tgt_image = self.common_dataset_transform(tgt_image, dataset_index, None) + else: + pass + + h, w, _ = tgt_image.shape + + ### transformations & online data augmentations on images and landmarks + self.get_crop_coords(landmarks, (w, h), dataset_name, random_trans_scale=0) # 30.5 µs ± 348 ns random translation + + transform_tgt = self.get_transform(dataset_name, True, n_img=1, n_keypoint=1, flip=False) + transformed_tgt = transform_tgt(image=tgt_image, keypoints=landmarks) + + tgt_image, points = transformed_tgt['image'], np.array(transformed_tgt['keypoints']).astype(np.float32) + + feature_map = self.get_feature_image(points, (self.opt.loadSize, self.opt.loadSize), shoulders, self.image_pad[dataset_index])[np.newaxis, :].astype(np.float32)/255. 
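`FaceDataset.__getitem__` (like `AudioVisualDataset.__getitem__` earlier in this diff) recovers which clip a flat sample index belongs to by running `bisect.bisect_right` over the cumulative `sample_start` offsets. A tiny self-contained illustration of that mapping, with made-up clip lengths:

```python
import bisect

# cumulative start offsets for three clips of 5, 3 and 4 samples
lens = [5, 3, 4]
sample_start = [0]
for n in lens[:-1]:
    sample_start.append(sample_start[-1] + n)   # -> [0, 5, 8]

def locate(index):
    """Return (clip_index, frame_within_clip) for a flat sample index."""
    clip = bisect.bisect_right(sample_start, index) - 1
    return clip, index - sample_start[clip]

print(locate(0))   # (0, 0)
print(locate(6))   # (1, 1)
print(locate(9))   # (2, 1)
```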
+ feature_map = torch.from_numpy(feature_map) + + ## facial weight mask + weight_mask = self.generate_facial_weight_mask(points, h, w)[None, :] + + cand_image = self.full_cand[dataset_index] + + return_list = {'feature_map': feature_map, 'cand_image': cand_image, 'tgt_image': tgt_image, 'weight_mask': weight_mask} + + return return_list + + + + + def common_dataset_transform(self, input, i): + output = self.image_transforms[i](image=input)['image'] + if self.image_pad[i] is not None: + top, bottom, left, right = self.image_pad[i] + output = cv2.copyMakeBorder(output, top, bottom, left, right, cv2.BORDER_CONSTANT, value = 0) + return output + + + + def generate_facial_weight_mask(self, points, h = 512, w = 512): + mouth_mask = np.zeros([512, 512, 1]) + points = points[self.mouth_outer] + points = np.int32(points) + mouth_mask = cv2.fillPoly(mouth_mask, [points], (255,0,0)) +# plt.imshow(mouth_mask[:,:,0]) + mouth_mask = cv2.dilate(mouth_mask, np.ones((45, 45))) / 255 + + return mouth_mask.astype(np.float32) + + + + def get_transform(self, dataset_name, keypoints=False, n_img=1, n_keypoint=1, normalize=True, flip=False): + min_x = getattr(self, 'min_x_' + str(dataset_name)) + max_x = getattr(self, 'max_x_' + str(dataset_name)) + min_y = getattr(self, 'min_y_' + str(dataset_name)) + max_y = getattr(self, 'max_y_' + str(dataset_name)) + + additional_flag = False + additional_targets_dict = {} + if n_img > 1: + additional_flag = True + image_str = ['image' + str(i) for i in range(0, n_img)] + for i in range(n_img): + additional_targets_dict[image_str[i]] = 'image' + if n_keypoint > 1: + additional_flag = True + keypoint_str = ['keypoint' + str(i) for i in range(0, n_keypoint)] + for i in range(n_keypoint): + additional_targets_dict[keypoint_str[i]] = 'keypoints' + + transform = A.Compose([ + A.Crop(x_min=min_x, x_max=max_x, y_min=min_y, y_max=max_y), + A.Resize(self.opt.loadSize, self.opt.loadSize), + A.HorizontalFlip(p=flip), + 
A.pytorch.transforms.ToTensor(normalize={'mean':(0.5,0.5,0.5), 'std':(0.5,0.5,0.5)} if normalize==True else None)], + keypoint_params=A.KeypointParams(format='xy', remove_invisible=False) if keypoints==True else None, + additional_targets=additional_targets_dict if additional_flag == True else None + ) + return transform + + + def get_data_test_mode(self, landmarks, shoulder, pad=None): + ''' get transformed data + ''' + + feature_map = torch.from_numpy(self.get_feature_image(landmarks, (self.opt.loadSize, self.opt.loadSize), shoulder, pad)[np.newaxis, :].astype(np.float32)/255.) + + return feature_map + + + def get_feature_image(self, landmarks, size, shoulders=None, image_pad=None): + # draw edges + im_edges = self.draw_face_feature_maps(landmarks, size) + if shoulders is not None: + if image_pad is not None: + top, bottom, left, right = image_pad + delta_y = top - bottom + delta_x = right - left + shoulders[:, 0] += delta_x + shoulders[:, 1] += delta_y + im_edges = self.draw_shoulder_points(im_edges, shoulders) + + + return im_edges + + + def draw_shoulder_points(self, img, shoulder_points): + num = int(shoulder_points.shape[0] / 2) + for i in range(2): + for j in range(num - 1): + pt1 = [int(flt) for flt in shoulder_points[i * num + j]] + pt2 = [int(flt) for flt in shoulder_points[i * num + j + 1]] + img = cv2.line(img, tuple(pt1), tuple(pt2), 255, 2) # BGR + + return img + + + def draw_face_feature_maps(self, keypoints, size=(512, 512)): + w, h = size + # edge map for face region from keypoints + im_edges = np.zeros((h, w), np.uint8) # edge map for all edges + for edge_list in self.part_list: + for edge in edge_list: + for i in range(len(edge)-1): + pt1 = [int(flt) for flt in keypoints[edge[i]]] + pt2 = [int(flt) for flt in keypoints[edge[i + 1]]] + im_edges = cv2.line(im_edges, tuple(pt1), tuple(pt2), 255, 2) + + return im_edges + + + def get_crop_coords(self, keypoints, size, dataset_name, random_trans_scale=50): + # cut a rough region for fine cutting + #
here x towards right and y towards down, origin is the upper-left corner + w_ori, h_ori = size + min_y, max_y = keypoints[:,1].min(), keypoints[:,1].max() + min_x, max_x = keypoints[:,0].min(), keypoints[:,0].max() + xc = (min_x + max_x) // 2 + yc = (min_y*3 + max_y) // 4 + h = w = min((max_x - min_x) * 2, w_ori, h_ori) + + if self.opt.isTrain: + # do online augment on landmarks & images + # 1. random translation: move 10% + x_bias, y_bias = np.random.uniform(-random_trans_scale, random_trans_scale, size=(2,)) + xc, yc = xc + x_bias, yc + y_bias + + # modify the center x, center y to valid position + xc = min(max(0, xc - w//2) + w, w_ori) - w//2 + yc = min(max(0, yc - h//2) + h, h_ori) - h//2 + + min_x, max_x = xc - w//2, xc + w//2 + min_y, max_y = yc - h//2, yc + h//2 + + setattr(self, 'min_x_' + str(dataset_name), int(min_x)) + setattr(self, 'max_x_' + str(dataset_name), int(max_x)) + setattr(self, 'min_y_' + str(dataset_name), int(min_y)) + setattr(self, 'max_y_' + str(dataset_name), int(max_y)) + + + def crop(self, img, dataset_name): + min_x = getattr(self, 'min_x_' + str(dataset_name)) + max_x = getattr(self, 'max_x_' + str(dataset_name)) + min_y = getattr(self, 'min_y_' + str(dataset_name)) + max_y = getattr(self, 'max_y_' + str(dataset_name)) + if isinstance(img, np.ndarray): + return img[min_y:max_y, min_x:max_x] + else: + return img.crop((min_x, min_y, max_x, max_y)) + + + def __len__(self): + if self.opt.isTrain: + return self.total_len + else: + return 1 + + def name(self): + return 'FaceDataset' + + diff --git a/talkingface/data/dataset/__init__.py b/talkingface/data/dataset/__init__.py index 3fd37538..36315af0 100644 --- a/talkingface/data/dataset/__init__.py +++ b/talkingface/data/dataset/__init__.py @@ -1,2 +1,3 @@ -from talkingface.data.dataset.wav2lip_dataset import Wav2LipDataset -from talkingface.data.dataset.dataset import Dataset \ No newline at end of file +from talkingface.data.dataset.dataset import Dataset +from
talkingface.data.dataset.live_speech_portraits_dataset import * +from talkingface.data.dataset.wav2lip_dataset import Wav2LipDataset \ No newline at end of file diff --git a/talkingface/data/dataset/dataset.py b/talkingface/data/dataset/dataset.py index 2de27bd7..9ffb6bad 100644 --- a/talkingface/data/dataset/dataset.py +++ b/talkingface/data/dataset/dataset.py @@ -1,4 +1,14 @@ +"""This module implements an abstract base class (ABC) 'dataset' for datasets. + +It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. +""" import torch +import torch.utils.data as data +from abc import ABC, abstractmethod +from PIL import Image +import torchvision.transforms as transforms +import numpy as np + class Dataset(torch.utils.data.Dataset): def __init__(self, config, datasplit): @@ -13,13 +23,32 @@ def __init__(self, config, datasplit): self.config = config self.split = datasplit - def __getitem__(self): + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. - """ - Returns: - data: dict, 必须是一个字典格式, 具体数据解析在model文件里解析 + Parameters: + index -- a random integer for data indexing + Returns: + a dictionary of data with their names. It usually contains the data itself and its metadata information. """ + pass - raise NotImplementedError \ No newline at end of file + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser.
+ """ + return parser \ No newline at end of file diff --git a/talkingface/data/dataset/live_speech_portraits_dataset.py b/talkingface/data/dataset/live_speech_portraits_dataset.py new file mode 100644 index 00000000..1eab5e2e --- /dev/null +++ b/talkingface/data/dataset/live_speech_portraits_dataset.py @@ -0,0 +1,855 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import lr_scheduler +from torch.nn import init +import functools +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence +import numpy as np +from collections import OrderedDict +from torch.cuda.amp import autocast as autocast +from abc import ABC, abstractmethod +from talkingface.data.dataset.dataset import Dataset +import librosa +import scipy.io as sio +import bisect +from talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.networks import APC_encoder +# from funcs import utils +from pathlib import Path +from skimage.io import imread, imsave +from PIL import Image +import io +import cv2 +import h5py +import albumentations as A + +class BaseDataset(Dataset, ABC): + """This class is an abstract base class (ABC) for datasets. + + To create a subclass, you need to implement the following four functions: + -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). + -- <__len__>: return the size of dataset. + -- <__getitem__>: get a data point. + -- : (optionally) add dataset-specific options and set default options. 
+ """ + + def __init__(self, opt, datasplit): + """Initialize the class; save the options in the class + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + self.opt = opt + self.root = opt.dataset_params['root'] + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def __len__(self): + """Return the total number of images in the dataset.""" + return 0 + + @abstractmethod + def __getitem__(self, index): + """Return a data point and its metadata information. + + Parameters: + index - - a random integer for data indexing + + Returns: + a dictionary of data with their names. It ususally contains the data itself and its metadata information. + """ + pass + + +class AudioVisualDataset(BaseDataset): + """ audio-visual dataset. currently, return 2D info and 3D tracking info. 
+ + # for wavenet: + # |----receptive_field----| + # |--output_length--| + # example: | | | | | | | | | | | | | | | | | | | | | + # target: | | | | | | | | | | + + """ + + def __init__(self, opt, datasplit): + # save the option and dataset root + BaseDataset.__init__(self, opt, datasplit) + self.isTrain = self.opt['Train'] + + return + self.state = opt.dataset_type + self.dataset_name = opt.dataset_names + self.target_length = opt.time_frame_length + self.sample_rate = opt.sample_rate + self.fps = opt.FPS + + self.audioRF_history = opt.audioRF_history + self.audioRF_future = opt.audioRF_future + self.compute_mel_online = opt.compute_mel_online + self.feature_name = opt.feature_name + + self.audio_samples_one_frame = self.sample_rate / self.fps + self.frame_jump_stride = opt.frame_jump_stride + self.augment = False + self.task = opt.task + self.item_length_audio = int((self.audioRF_history + self.audioRF_future) / self.fps * self.sample_rate) + + if self.task == 'Audio2Feature': + if opt.feature_decoder == 'WaveNet': + self.A2L_receptive_field = opt.A2L_receptive_field + self.A2L_item_length = self.A2L_receptive_field + self.target_length - 1 + elif opt.feature_decoder == 'LSTM': + self.A2L_receptive_field = 30 + self.A2L_item_length = self.A2L_receptive_field + self.target_length - 1 + elif self.task == 'Audio2Headpose': + self.A2H_receptive_field = opt.A2H_receptive_field + self.A2H_item_length = self.A2H_receptive_field + self.target_length - 1 + self.audio_window = opt.audio_windows + self.half_audio_win = int(self.audio_window / 2) + + self.frame_future = opt.frame_future + self.predict_length = opt.predict_length + self.predict_len = int((self.predict_length - 1) / 2) + + self.gpu_ids = opt.gpu_ids + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + print('self.device:', self.device) + if self.task == 'Audio2Feature': + self.seq_len = opt.sequence_length + + self.total_len = 0 + self.dataset_root = 
os.path.join(self.root, self.dataset_name) + if self.state == 'Train': + self.clip_names = opt.train_dataset_names + elif self.state == 'Val': + self.clip_names = opt.validate_dataset_names + elif self.state == 'Test': + self.clip_names = opt.test_dataset_names + + self.clip_nums = len(self.clip_names) + # main info + self.audio = [''] * self.clip_nums + self.audio_features = [''] * self.clip_nums + self.feats = [''] * self.clip_nums + self.exps = [''] * self.clip_nums + self.pts3d = [''] * self.clip_nums + self.rot_angles = [''] * self.clip_nums + self.trans = [''] * self.clip_nums + self.headposes = [''] * self.clip_nums + self.velocity_pose = [''] * self.clip_nums + self.acceleration_pose = [''] * self.clip_nums + self.mean_trans = [''] * self.clip_nums + if self.state == 'Test': + self.landmarks = [''] * self.clip_nums + # meta info + self.start_point = [''] * self.clip_nums + self.end_point = [''] * self.clip_nums + self.len = [''] * self.clip_nums + self.sample_start = [] + self.clip_valid = ['True'] * self.clip_nums + self.invalid_clip = [] + + self.mouth_related_indices = np.concatenate([np.arange(4, 11), np.arange(46, 64)]) + if self.task == 'Audio2Feature': + if self.opt.only_mouth: + self.indices = self.mouth_related_indices + else: + self.indices = np.arange(73) + if opt.use_delta_pts: + self.pts3d_mean = np.load(os.path.join(self.dataset_root, 'mean_pts3d.npy')) + + for i in range(self.clip_nums): + name = self.clip_names[i] + clip_root = os.path.join(self.dataset_root, name) + # audio + if os.path.exists(os.path.join(clip_root, name + '_denoise.wav')): + audio_path = os.path.join(clip_root, name + '_denoise.wav') + print('find denoised wav!') + else: + audio_path = os.path.join(clip_root, name + '.wav') + self.audio[i], _ = librosa.load(audio_path, sr=self.sample_rate) + + if self.opt.audio_encoder == 'APC': + APC_name = os.path.split(self.opt.APC_model_path)[-1] + APC_feature_file = name + '_APC_feature_V0324_ckpt_{}.npy'.format(APC_name) + 
APC_feature_path = os.path.join(clip_root, APC_feature_file) + need_deepfeats = False if os.path.exists(APC_feature_path) else True + if not need_deepfeats: + self.audio_features[i] = np.load(APC_feature_path).astype(np.float32) + else: + need_deepfeats = False + + # 3D landmarks & headposes + if self.task == 'Audio2Feature': + self.start_point[i] = 0 + elif self.task == 'Audio2Headpose': + self.start_point[i] = 300 + fit_data_path = os.path.join(clip_root, '3d_fit_data.npz') + fit_data = np.load(fit_data_path) + if not opt.ispts_norm: + ori_pts3d = fit_data['pts_3d'].astype(np.float32) + else: + ori_pts3d = np.load(os.path.join(clip_root, 'tracked3D_normalized_pts_fix_contour.npy')) + if opt.use_delta_pts: + self.pts3d[i] = ori_pts3d - self.pts3d_mean + else: + self.pts3d[i] = ori_pts3d + if opt.feature_dtype == 'pts3d': + self.feats[i] = self.pts3d[i] + elif opt.feature_dtype == 'FW': + track_data_path = os.path.join(clip_root, 'tracking_results.mat') + self.feats[i] = sio.loadmat(track_data_path)['exps'].astype(np.float32) + self.rot_angles[i] = fit_data['rot_angles'].astype(np.float32) + # change -180~180 to 0~360 + if not self.dataset_name == 'Yuxuan': + rot_change = self.rot_angles[i][:, 0] < 0 + self.rot_angles[i][rot_change, 0] += 360 + self.rot_angles[i][:, 0] -= 180 # change x axis direction + # use delta translation + self.mean_trans[i] = fit_data['trans'][:, :, 0].astype(np.float32).mean(axis=0) + self.trans[i] = fit_data['trans'][:, :, 0].astype(np.float32) - self.mean_trans[i] + + self.headposes[i] = np.concatenate([self.rot_angles[i], self.trans[i]], axis=1) + self.velocity_pose[i] = np.concatenate( + [np.zeros(6)[None, :], self.headposes[i][1:] - self.headposes[i][:-1]]) + self.acceleration_pose[i] = np.concatenate( + [np.zeros(6)[None, :], self.velocity_pose[i][1:] - self.velocity_pose[i][:-1]]) + + if self.dataset_name == 'Yuxuan': + total_frames = self.feats[i].shape[0] - 300 - 130 + else: + total_frames = self.feats[i].shape[0] - 60 + + if 
need_deepfeats: + if self.opt.audio_encoder == 'APC': + print('dataset {} needs to pre-compute APC features ...'.format(name)) + print('first we compute mel spectrogram for dataset {} '.format(name)) + mel80 = utils.compute_mel_one_sequence(self.audio[i]) + mel_nframe = mel80.shape[0] + print('loading pre-trained model: ', self.opt.APC_model_path) + APC_model = APC_encoder(self.opt.audiofeature_input_channels, + self.opt.APC_hidden_size, + self.opt.APC_rnn_layers, + self.opt.APC_residual) + APC_model.load_state_dict(torch.load(self.opt.APC_model_path, map_location=str(self.device)), + strict=False) + # APC_model.load_state_dict(torch.load(self.opt.APC_model_path), strict=False) + APC_model.cuda() + APC_model.eval() + with torch.no_grad(): + length = torch.Tensor([mel_nframe]) + # hidden_reps = torch.zeros([mel_nframe, self.opt.APC_hidden_size]).cuda() + mel80_torch = torch.from_numpy(mel80.astype(np.float32)).cuda().unsqueeze(0) + hidden_reps = APC_model.forward(mel80_torch, length)[0] # [mel_nframe, 512] + hidden_reps = hidden_reps.cpu().numpy() + np.save(APC_feature_path, hidden_reps) + self.audio_features[i] = hidden_reps + + valid_frames = total_frames - self.start_point[i] + self.len[i] = valid_frames - 400 + if i == 0: + self.sample_start.append(0) + else: + self.sample_start.append(self.sample_start[-1] + self.len[i - 1] - 1) + self.total_len += np.int32(np.floor(self.len[i] / self.frame_jump_stride)) + + def __getitem__(self, index): + # recover real index from compressed one + index_real = np.int32(index * self.frame_jump_stride) + # find which audio file and the start frame index + file_index = bisect.bisect_right(self.sample_start, index_real) - 1 + current_frame = index_real - self.sample_start[file_index] + self.start_point[file_index] + current_target_length = self.target_length + + if self.task == 'Audio2Feature': + # start point is current frame + A2Lsamples = self.audio_features[file_index][current_frame * 2: (current_frame + self.seq_len) * 2] +
target_pts3d = self.feats[file_index][current_frame: current_frame + self.seq_len, self.indices].reshape( + self.seq_len, -1) + + A2Lsamples = torch.from_numpy(A2Lsamples).float() + target_pts3d = torch.from_numpy(target_pts3d).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Lsamples, target_pts3d + + + elif self.task == 'Audio2Headpose': + if self.opt.feature_decoder == 'WaveNet': + # find the history info start points + A2H_history_start = current_frame - self.A2H_receptive_field + A2H_item_length = self.A2H_item_length + A2H_receptive_field = self.A2H_receptive_field + + if self.half_audio_win == 1: + A2Hsamples = self.audio_features[file_index][2 * (A2H_history_start + self.frame_future): 2 * ( + A2H_history_start + self.frame_future + A2H_item_length)] + else: + A2Hsamples = np.zeros([A2H_item_length, self.audio_window, 512]) + for i in range(A2H_item_length): + A2Hsamples[i] = self.audio_features[file_index][ + 2 * (A2H_history_start + i) - self.half_audio_win: 2 * ( + A2H_history_start + i) + self.half_audio_win] + + if self.predict_len == 0: + target_headpose = self.headposes[file_index][ + A2H_history_start + A2H_receptive_field: A2H_history_start + A2H_item_length + 1] + history_headpose = self.headposes[file_index][ + A2H_history_start: A2H_history_start + A2H_item_length].reshape(A2H_item_length, + -1) + + target_velocity = self.velocity_pose[file_index][ + A2H_history_start + A2H_receptive_field: A2H_history_start + A2H_item_length + 1] + history_velocity = self.velocity_pose[file_index][ + A2H_history_start: A2H_history_start + A2H_item_length].reshape(A2H_item_length, + -1) + target_info = torch.from_numpy( + np.concatenate([target_headpose, target_velocity], axis=1).reshape(current_target_length, + -1)).float() + else: + history_headpose = self.headposes[file_index][ + A2H_history_start: A2H_history_start + A2H_item_length].reshape(A2H_item_length, + -1) + history_velocity = 
self.velocity_pose[file_index][ + A2H_history_start: A2H_history_start + A2H_item_length].reshape(A2H_item_length, + -1) + + target_headpose_ = self.headposes[file_index][ + A2H_history_start + A2H_receptive_field - self.predict_len: A2H_history_start + A2H_item_length + 1 + self.predict_len + 1] + target_headpose = np.zeros([current_target_length, self.predict_length, target_headpose_.shape[1]]) + for i in range(current_target_length): + target_headpose[i] = target_headpose_[i: i + self.predict_length] + target_headpose = target_headpose # .reshape(current_target_length, -1, order='F') + + target_velocity_ = self.headposes[file_index][ + A2H_history_start + A2H_receptive_field - self.predict_len: A2H_history_start + A2H_item_length + 1 + self.predict_len + 1] + target_velocity = np.zeros([current_target_length, self.predict_length, target_velocity_.shape[1]]) + for i in range(current_target_length): + target_velocity[i] = target_velocity_[i: i + self.predict_length] + target_velocity = target_velocity # .reshape(current_target_length, -1, order='F') + + target_info = torch.from_numpy( + np.concatenate([target_headpose, target_velocity], axis=2).reshape(current_target_length, + -1)).float() + + A2Hsamples = torch.from_numpy(A2Hsamples).float() + + history_info = torch.from_numpy(np.concatenate([history_headpose, history_velocity], axis=1)).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Hsamples, history_info, target_info + + + elif self.opt.feature_decoder == 'LSTM': + A2Hsamples = self.audio_features[file_index][ + current_frame * 2: (current_frame + self.opt.A2H_receptive_field) * 2] + + target_headpose = self.headposes[file_index][ + current_frame: current_frame + self.opt.A2H_receptive_field] + target_velocity = self.velocity_pose[file_index][ + current_frame: current_frame + self.opt.A2H_receptive_field] + target_info = torch.from_numpy( + np.concatenate([target_headpose, target_velocity], 
axis=1).reshape(self.opt.A2H_receptive_field, + -1)).float() + + A2Hsamples = torch.from_numpy(A2Hsamples).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Hsamples, target_info + + def __len__(self): + return self.total_len + + +class FaceDataset(BaseDataset): + def __init__(self, opt, datasplit): + """Initialize this dataset class. + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + + A few things can be done here. + - save the options (have been done in BaseDataset) + - get image paths and meta information of the dataset. + - define the image transformation. + """ + BaseDataset.__init__(self, opt, datasplit) + + + self.state = 'Train' if self.opt.train else 'Test' + return + self.dataset_name = opt.dataset_names[0] + # default settings + # currently, we have 8 parts for face parts + self.part_list = [[list(range(0, 15))], # contour + [[15, 16, 17, 18, 18, 19, 20, 15]], # right eyebrow + [[21, 22, 23, 24, 24, 25, 26, 21]], # left eyebrow + [range(35, 44)], # nose + [[27, 65, 28, 68, 29], [29, 67, 30, 66, 27]], # right eye + [[33, 69, 32, 72, 31], [31, 71, 34, 70, 33]], # left eye + [range(46, 53), [52, 53, 54, 55, 56, 57, 46]], # mouth + [[46, 63, 62, 61, 52], [52, 60, 59, 58, 46]] # tongue + ] + self.mouth_outer = [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 46] + self.label_list = [1, 1, 2, 3, 3, 4, 5] # labeling for different facial parts + + # only load in train mode + + self.dataset_root = os.path.join(self.root, self.dataset_name) + if self.state == 'Train': + self.clip_names = opt.train_dataset_names + elif self.state == 'Val': + self.clip_names = opt.validate_dataset_names + elif self.state == 'Test': + self.clip_names = opt.test_dataset_names + + self.clip_nums = len(self.clip_names) + + # load pts & image info + self.landmarks2D, self.len, self.sample_len = [''] * self.clip_nums, [''] * self.clip_nums, [ + ''] * self.clip_nums + 
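The `__init__` below derives a 512×512 crop window centred at `(xc, yc)` from `change_paras.npz`, clips it to the image, and records how much padding would restore the full canvas. The arithmetic can be isolated as follows (a sketch only; `half=256` mirrors the hard-coded window size, and the function name is mine):

```python
def crop_bounds(xc, yc, w, h, half=256):
    """Clip a (2*half x 2*half) window centred at (xc, yc) to a w x h image.

    Returns the clipped bounds plus the [top, bottom, left, right] padding
    that would recover the full window size, as in the dataset __init__.
    """
    x_min, x_max = max(xc - half, 0), min(xc + half, w)
    y_min, y_max = max(yc - half, 0), min(yc + half, h)
    pad = [abs(yc - half - y_min), abs(yc + half - y_max),
           abs(xc - half - x_min), abs(xc + half - x_max)]
    return (x_min, x_max, y_min, y_max), pad
```

For `crop_bounds(100, 300, 600, 600)` the window sticks out on the left, so the bounds are clipped to `(0, 356, 44, 556)` and `pad` becomes `[0, 0, 156, 0]`.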
self.image_transforms, self.image_pad, self.tgts_paths = [''] * self.clip_nums, [''] * self.clip_nums, [ + ''] * self.clip_nums + self.shoulders, self.shoulder3D = [''] * self.clip_nums, [''] * self.clip_nums + self.sample_start = [] + + # tracked 3d info & candidates images + self.pts3d, self.rot, self.trans = [''] * self.clip_nums, [''] * self.clip_nums, [''] * self.clip_nums + self.full_cand = [''] * self.clip_nums + self.headposes = [''] * self.clip_nums + + self.total_len = 0 + if self.opt.isTrain: + for i in range(self.clip_nums): + name = self.clip_names[i] + clip_root = os.path.join(self.dataset_root, name) + # basic image info + img_file_path = os.path.join(clip_root, name + '.h5') + img_file = h5py.File(img_file_path, 'r')[name] + example = np.asarray(Image.open(io.BytesIO(img_file[0]))) + h, w, _ = example.shape + + landmark_path = os.path.join(clip_root, 'tracked2D_normalized_pts_fix_contour.npy') + self.landmarks2D[i] = np.load(landmark_path).astype(np.float32) + change_paras = np.load(os.path.join(clip_root, 'change_paras.npz')) + scale, xc, yc = change_paras['scale'], change_paras['xc'], change_paras['yc'] + x_min, x_max, y_min, y_max = xc - 256, xc + 256, yc - 256, yc + 256 + # if need padding + x_min, x_max, y_min, y_max, self.image_pad[i] = max(x_min, 0), min(x_max, w), max(y_min, 0), min(y_max, + h), None + + if x_min == 0 or x_max == 512 or y_min == 0 or y_max == 512: + top, bottom, left, right = abs(yc - 256 - y_min), abs(yc + 256 - y_max), abs(xc - 256 - x_min), abs( + xc + 256 - x_max) + self.image_pad[i] = [top, bottom, left, right] + self.image_transforms[i] = A.Compose([ + A.Resize(np.int32(h * scale), np.int32(w * scale)), + A.Crop(x_min, y_min, x_max, y_max)]) + + if self.opt.isH5: + tgt_file_path = os.path.join(clip_root, name + '.h5') + tgt_file = h5py.File(tgt_file_path, 'r')[name] + image_length = len(tgt_file) + else: + tgt_paths = list(map(lambda x: str(x), sorted(list(Path(clip_root).glob('*' + self.opt.suffix)), + key=lambda x: 
int(x.stem)))) + image_length = len(tgt_paths) + self.tgts_paths[i] = tgt_paths + if not self.landmarks2D[i].shape[0] == image_length: + raise ValueError('In dataset {} length of landmarks and images are not equal!'.format(name)) + + # tracked 3d info + fit_data_path = os.path.join(clip_root, '3d_fit_data.npz') + fit_data = np.load(fit_data_path) + self.pts3d[i] = fit_data['pts_3d'].astype(np.float32) + self.rot[i] = fit_data['rot_angles'].astype(np.float32) + self.trans[i] = fit_data['trans'][:, :, 0].astype(np.float32) + if not self.pts3d[i].shape[0] == image_length: + raise ValueError('In dataset {} length of 3d pts and images are not equal!'.format(name)) + + # candidates images + + tmp = [] + for j in range(4): + try: + output = imread(os.path.join(clip_root, 'candidates', f'normalized_full_{j}.jpg')) + except: + imgc = imread(os.path.join(clip_root, 'candidates', f'full_{j}.jpg')) + output = self.common_dataset_transform(imgc, i) + imsave(os.path.join(clip_root, 'candidates', f'normalized_full_{j}.jpg'), output) + output = A.pytorch.transforms.ToTensor(normalize={'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5)})( + image=output)['image'] + tmp.append(output) + self.full_cand[i] = torch.cat(tmp) + + # headpose + fit_data_path = os.path.join(clip_root, '3d_fit_data.npz') + fit_data = np.load(fit_data_path) + rot_angles = fit_data['rot_angles'].astype(np.float32) + # change -180~180 to 0~360 + if not self.dataset_name == 'Yuxuan': + rot_change = rot_angles[:, 0] < 0 + rot_angles[rot_change, 0] += 360 + rot_angles[:, 0] -= 180 # change x axis direction + # use delta translation + mean_trans = fit_data['trans'][:, :, 0].astype(np.float32).mean(axis=0) + trans = fit_data['trans'][:, :, 0].astype(np.float32) - mean_trans + + self.headposes[i] = np.concatenate([rot_angles, trans], axis=1) + + # shoulders + shoulder_path = os.path.join(clip_root, 'normalized_shoulder_points.npy') + self.shoulders[i] = np.load(shoulder_path) + shoulder3D_path = os.path.join(clip_root, 
'shoulder_points3D.npy') + self.shoulder3D[i] = np.load(shoulder3D_path) + + self.sample_len[i] = np.int32(np.floor((self.landmarks2D[i].shape[0] - 60) / self.opt.frame_jump) + 1) + self.len[i] = self.landmarks2D[i].shape[0] + if i == 0: + self.sample_start.append(0) + else: + self.sample_start.append(self.sample_start[-1] + self.sample_len[i - 1]) # not minus 1 + self.total_len += self.sample_len[i] + + # test mode + else: + # if need padding + example = imread(os.path.join(self.root, 'example.png')) + h, w, _ = example.shape + change_paras = np.load(os.path.join(self.root, 'change_paras.npz')) + scale, xc, yc = change_paras['scale'], change_paras['xc'], change_paras['yc'] + x_min, x_max, y_min, y_max = xc - 256, xc + 256, yc - 256, yc + 256 + x_min, x_max, y_min, y_max, self.image_pad = max(x_min, 0), min(x_max, w), max(y_min, 0), min(y_max, + h), None + + if x_min == 0 or x_max == 512 or y_min == 0 or y_max == 512: + top, bottom, left, right = abs(yc - 256 - y_min), abs(yc + 256 - y_max), abs(xc - 256 - x_min), abs( + xc + 256 - x_max) + self.image_pad = [top, bottom, left, right] + + def __getitem__(self, ind): + dataset_index = bisect.bisect_right(self.sample_start, ind) - 1 + data_index = (ind - self.sample_start[dataset_index]) * self.opt.frame_jump + np.random.randint( + self.opt.frame_jump) + + target_ind = data_index + 1 # history_ind, current_ind + landmarks = self.landmarks2D[dataset_index][target_ind] # [73, 2] + shoulders = self.shoulders[dataset_index][target_ind].copy() + + dataset_name = self.clip_names[dataset_index] + clip_root = os.path.join(self.dataset_root, dataset_name) + if self.opt.isH5: + tgt_file_path = os.path.join(clip_root, dataset_name + '.h5') + tgt_file = h5py.File(tgt_file_path, 'r')[dataset_name] + tgt_image = np.asarray(Image.open(io.BytesIO(tgt_file[target_ind]))) + + # do transform + tgt_image = self.common_dataset_transform(tgt_image, dataset_index, None) + else: + pass + + h, w, _ = tgt_image.shape + + ### transformations & 
online data augmentations on images and landmarks + self.get_crop_coords(landmarks, (w, h), dataset_name, + random_trans_scale=0) # 30.5 µs ± 348 ns random translation + + transform_tgt = self.get_transform(dataset_name, True, n_img=1, n_keypoint=1, flip=False) + transformed_tgt = transform_tgt(image=tgt_image, keypoints=landmarks) + + tgt_image, points = transformed_tgt['image'], np.array(transformed_tgt['keypoints']).astype(np.float32) + + feature_map = self.get_feature_image(points, (self.opt.loadSize, self.opt.loadSize), shoulders, + self.image_pad[dataset_index])[np.newaxis, :].astype(np.float32) / 255. + feature_map = torch.from_numpy(feature_map) + + ## facial weight mask + weight_mask = self.generate_facial_weight_mask(points, h, w)[None, :] + + cand_image = self.full_cand[dataset_index] + + return_list = {'feature_map': feature_map, 'cand_image': cand_image, 'tgt_image': tgt_image, + 'weight_mask': weight_mask} + + return return_list + + def common_dataset_transform(self, input, i): + output = self.image_transforms[i](image=input)['image'] + if self.image_pad[i] is not None: + top, bottom, left, right = self.image_pad[i] + output = cv2.copyMakeBorder(output, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0) + return output + + def generate_facial_weight_mask(self, points, h=512, w=512): + mouth_mask = np.zeros([512, 512, 1]) + points = points[self.mouth_outer] + points = np.int32(points) + mouth_mask = cv2.fillPoly(mouth_mask, [points], (255, 0, 0)) + # plt.imshow(mouth_mask[:,:,0]) + mouth_mask = cv2.dilate(mouth_mask, np.ones((45, 45))) / 255 + + return mouth_mask.astype(np.float32) + + def get_transform(self, dataset_name, keypoints=False, n_img=1, n_keypoint=1, normalize=True, flip=False): + min_x = getattr(self, 'min_x_' + str(dataset_name)) + max_x = getattr(self, 'max_x_' + str(dataset_name)) + min_y = getattr(self, 'min_y_' + str(dataset_name)) + max_y = getattr(self, 'max_y_' + str(dataset_name)) + + additional_flag = False + 
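`get_transform` assembles an `additional_targets` mapping so that albumentations applies the same geometric transform to every extra image or keypoint set. The dict construction it performs can be distilled to this sketch (function name is mine; the returned dict is what gets passed to `A.Compose`):

```python
def build_additional_targets(n_img=1, n_keypoint=1):
    """Extra named inputs that albumentations should transform like the primary ones.

    Returns None when there is only one image and one keypoint set,
    matching the `additional_flag` logic in `get_transform`.
    """
    targets = {}
    if n_img > 1:
        for i in range(n_img):
            targets['image' + str(i)] = 'image'
    if n_keypoint > 1:
        for i in range(n_keypoint):
            targets['keypoint' + str(i)] = 'keypoints'
    return targets or None
```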
additional_targets_dict = {} + if n_img > 1: + additional_flag = True + image_str = ['image' + str(i) for i in range(0, n_img)] + for i in range(n_img): + additional_targets_dict[image_str[i]] = 'image' + if n_keypoint > 1: + additional_flag = True + keypoint_str = ['keypoint' + str(i) for i in range(0, n_keypoint)] + for i in range(n_keypoint): + additional_targets_dict[keypoint_str[i]] = 'keypoints' + + transform = A.Compose([ + A.Crop(x_min=min_x, x_max=max_x, y_min=min_y, y_max=max_y), + A.Resize(self.opt.loadSize, self.opt.loadSize), + A.HorizontalFlip(p=flip), + A.pytorch.transforms.ToTensor( + normalize={'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5)} if normalize == True else None)], + keypoint_params=A.KeypointParams(format='xy', remove_invisible=False) if keypoints == True else None, + additional_targets=additional_targets_dict if additional_flag == True else None + ) + return transform + + def get_data_test_mode(self, landmarks, shoulder, pad=None): + ''' get transformed data + ''' + + feature_map = torch.from_numpy( + self.get_feature_image(landmarks, (self.opt.loadSize, self.opt.loadSize), shoulder, pad)[np.newaxis, + :].astype(np.float32) / 255.) 
+
+        return feature_map
+
+    def get_feature_image(self, landmarks, size, shoulders=None, image_pad=None):
+        # draw edges
+        im_edges = self.draw_face_feature_maps(landmarks, size)
+        if shoulders is not None:
+            if image_pad is not None:
+                top, bottom, left, right = image_pad
+                delta_y = top - bottom
+                delta_x = right - left
+                shoulders[:, 0] += delta_x
+                shoulders[:, 1] += delta_y
+            im_edges = self.draw_shoulder_points(im_edges, shoulders)
+
+        return im_edges
+
+    def draw_shoulder_points(self, img, shoulder_points):
+        num = int(shoulder_points.shape[0] / 2)
+        for i in range(2):
+            for j in range(num - 1):
+                pt1 = [int(flt) for flt in shoulder_points[i * num + j]]
+                pt2 = [int(flt) for flt in shoulder_points[i * num + j + 1]]
+                img = cv2.line(img, tuple(pt1), tuple(pt2), 255, 2)  # BGR
+
+        return img
+
+    def draw_face_feature_maps(self, keypoints, size=(512, 512)):
+        w, h = size
+        # edge map for the face region, drawn from keypoints
+        im_edges = np.zeros((h, w), np.uint8)  # edge map for all edges
+        for edge_list in self.part_list:
+            for edge in edge_list:
+                for i in range(len(edge) - 1):
+                    pt1 = [int(flt) for flt in keypoints[edge[i]]]
+                    pt2 = [int(flt) for flt in keypoints[edge[i + 1]]]
+                    im_edges = cv2.line(im_edges, tuple(pt1), tuple(pt2), 255, 2)
+
+        return im_edges
+
+    def get_crop_coords(self, keypoints, size, dataset_name, random_trans_scale=50):
+        # cut a rough region for fine cutting
+        # here x points right and y points down; the origin is the top-left corner
+        w_ori, h_ori = size
+        min_y, max_y = keypoints[:, 1].min(), keypoints[:, 1].max()
+        min_x, max_x = keypoints[:, 0].min(), keypoints[:, 0].max()
+        xc = (min_x + max_x) // 2
+        yc = (min_y * 3 + max_y) // 4
+        h = w = min((max_x - min_x) * 2, w_ori, h_ori)
+
+        if self.opt.isTrain:
+            # do online augmentation on landmarks & images
+            # 1.
random translation: move 10% + x_bias, y_bias = np.random.uniform(-random_trans_scale, random_trans_scale, size=(2,)) + xc, yc = xc + x_bias, yc + y_bias + + # modify the center x, center y to valid position + xc = min(max(0, xc - w // 2) + w, w_ori) - w // 2 + yc = min(max(0, yc - h // 2) + h, h_ori) - h // 2 + + min_x, max_x = xc - w // 2, xc + w // 2 + min_y, max_y = yc - h // 2, yc + h // 2 + + setattr(self, 'min_x_' + str(dataset_name), int(min_x)) + setattr(self, 'max_x_' + str(dataset_name), int(max_x)) + setattr(self, 'min_y_' + str(dataset_name), int(min_y)) + setattr(self, 'max_y_' + str(dataset_name), int(max_y)) + + def crop(self, img, dataset_name): + min_x = getattr(self, 'min_x_' + str(dataset_name)) + max_x = getattr(self, 'max_x_' + str(dataset_name)) + min_y = getattr(self, 'min_y_' + str(dataset_name)) + max_y = getattr(self, 'max_y_' + str(dataset_name)) + if isinstance(img, np.ndarray): + return img[min_y:max_y, min_x:max_x] + else: + return img.crop((min_x, min_y, max_x, max_y)) + + def __len__(self): + if self.opt.train: + return self.total_len + else: + return 1 + + def name(self): + return 'FaceDataset' + + +class live_speech_portraitsDataset(Dataset): + + def __init__(self, config, datasplit): + + self.opt = config + self.split = datasplit + + @abstractmethod + def __getitem__(self, index): + + # recover real index from compressed one + index_real = np.int32(index * self.frame_jump_stride) + # find which audio file and the start frame index + file_index = bisect.bisect_right(self.sample_start, index_real) - 1 + current_frame = index_real - self.sample_start[file_index] + self.start_point[file_index] + current_target_length = self.target_length + + if self.task == 'Audio2Feature': + # start point is current frame + A2Lsamples = self.audio_features[file_index][current_frame * 2: (current_frame + self.seq_len) * 2] + target_pts3d = self.feats[file_index][current_frame: current_frame + self.seq_len, self.indices].reshape( + self.seq_len, -1) + 
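Both dataset classes index audio features with a factor of two (`current_frame * 2 : (current_frame + self.seq_len) * 2`) because the APC features are extracted at twice the video frame rate. A toy illustration of that alignment (array sizes are made up):

```python
import numpy as np

# 120 made-up audio-feature rows covering 60 video frames (2 rows per frame).
audio_feats = np.arange(240, dtype=np.float32).reshape(120, 2)

def audio_window(current_frame, seq_len):
    """Audio-feature rows aligned with video frames [current_frame, current_frame + seq_len)."""
    return audio_feats[current_frame * 2:(current_frame + seq_len) * 2]
```

Here `audio_window(10, 5)` returns 10 audio rows for 5 video frames, starting at row 20.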
+ A2Lsamples = torch.from_numpy(A2Lsamples).float() + target_pts3d = torch.from_numpy(target_pts3d).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Lsamples, target_pts3d + + + elif self.task == 'Audio2Headpose': + if self.opt.feature_decoder == 'WaveNet': + # find the history info start points + A2H_history_start = current_frame - self.A2H_receptive_field + A2H_item_length = self.A2H_item_length + A2H_receptive_field = self.A2H_receptive_field + + if self.half_audio_win == 1: + A2Hsamples = self.audio_features[file_index][2 * (A2H_history_start + self.frame_future): 2 * ( + A2H_history_start + self.frame_future + A2H_item_length)] + else: + A2Hsamples = np.zeros([A2H_item_length, self.audio_window, 512]) + for i in range(A2H_item_length): + A2Hsamples[i] = self.audio_features[file_index][ + 2 * (A2H_history_start + i) - self.half_audio_win: 2 * ( + A2H_history_start + i) + self.half_audio_win] + + if self.predict_len == 0: + target_headpose = self.headposes[file_index][ + A2H_history_start + A2H_receptive_field: A2H_history_start + A2H_item_length + 1] + history_headpose = self.headposes[file_index][ + A2H_history_start: A2H_history_start + A2H_item_length].reshape(A2H_item_length, + -1) + + target_velocity = self.velocity_pose[file_index][ + A2H_history_start + A2H_receptive_field: A2H_history_start + A2H_item_length + 1] + history_velocity = self.velocity_pose[file_index][ + A2H_history_start: A2H_history_start + A2H_item_length].reshape(A2H_item_length, + -1) + target_info = torch.from_numpy( + np.concatenate([target_headpose, target_velocity], axis=1).reshape(current_target_length, + -1)).float() + else: + history_headpose = self.headposes[file_index][ + A2H_history_start: A2H_history_start + A2H_item_length].reshape(A2H_item_length, + -1) + history_velocity = self.velocity_pose[file_index][ + A2H_history_start: A2H_history_start + A2H_item_length].reshape(A2H_item_length, + -1) + + target_headpose_ = 
self.headposes[file_index][ + A2H_history_start + A2H_receptive_field - self.predict_len: A2H_history_start + A2H_item_length + 1 + self.predict_len + 1] + target_headpose = np.zeros([current_target_length, self.predict_length, target_headpose_.shape[1]]) + for i in range(current_target_length): + target_headpose[i] = target_headpose_[i: i + self.predict_length] + target_headpose = target_headpose # .reshape(current_target_length, -1, order='F') + + target_velocity_ = self.headposes[file_index][ + A2H_history_start + A2H_receptive_field - self.predict_len: A2H_history_start + A2H_item_length + 1 + self.predict_len + 1] + target_velocity = np.zeros([current_target_length, self.predict_length, target_velocity_.shape[1]]) + for i in range(current_target_length): + target_velocity[i] = target_velocity_[i: i + self.predict_length] + target_velocity = target_velocity # .reshape(current_target_length, -1, order='F') + + target_info = torch.from_numpy( + np.concatenate([target_headpose, target_velocity], axis=2).reshape(current_target_length, + -1)).float() + + A2Hsamples = torch.from_numpy(A2Hsamples).float() + + history_info = torch.from_numpy(np.concatenate([history_headpose, history_velocity], axis=1)).float() + + # [item_length, mel_channels, mel_width], or [item_length, APC_hidden_size] + return A2Hsamples, history_info, target_info + + + elif self.opt.feature_decoder == 'LSTM': + A2Hsamples = self.audio_features[file_index][ + current_frame * 2: (current_frame + self.opt.A2H_receptive_field) * 2] + + target_headpose = self.headposes[file_index][ + current_frame: current_frame + self.opt.A2H_receptive_field] + target_velocity = self.velocity_pose[file_index][ + current_frame: current_frame + self.opt.A2H_receptive_field] + target_info = torch.from_numpy( + np.concatenate([target_headpose, target_velocity], axis=1).reshape(self.opt.A2H_receptive_field, + -1)).float() + + A2Hsamples = torch.from_numpy(A2Hsamples).float() + + # [item_length, mel_channels, mel_width], or 
[item_length, APC_hidden_size]
+            return A2Hsamples, target_info
+
+    @abstractmethod
+    def __len__(self):
+        """Return the total number of images in the dataset."""
+        if self.opt.train:
+            return self.total_len
+        else:
+            return 1
+
+    def modify_commandline_options(parser, is_train):
+        return parser
diff --git a/talkingface/data/dataset/wav2lip_dataset.py b/talkingface/data/dataset/wav2lip_dataset.py
index e52ae28d..cb79402a 100644
--- a/talkingface/data/dataset/wav2lip_dataset.py
+++ b/talkingface/data/dataset/wav2lip_dataset.py
@@ -1,7 +1,8 @@
 from os.path import dirname, join, basename, isfile
 from tqdm import tqdm
 from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio
-import python_speech_features
+# This library does not seem to be used and could not be installed, so I commented it out here.
+# import python_speech_features
 import torch
 from torch import nn
diff --git a/talkingface/evaluator/__init__.py b/talkingface/evaluator/__init__.py
index d5eeefda..a30a5766 100644
--- a/talkingface/evaluator/__init__.py
+++ b/talkingface/evaluator/__init__.py
@@ -1,4 +1,4 @@
-from talkingface.evaluator.metric_models import *
 from talkingface.evaluator.metrics import *
+from talkingface.evaluator.metric_models import *
 from talkingface.evaluator.register import *
 from talkingface.evaluator.evaluator import *
\ No newline at end of file
diff --git a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/__init__.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/__init__.py
new file mode 100644
index 00000000..761f9524
--- /dev/null
+++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/__init__.py
@@ -0,0 +1,142 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+    -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+    -- <set_input>: unpack data from the dataset and apply preprocessing.
+    -- <forward>: produce intermediate results.
+    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
+    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+    -- self.loss_names (str list): specify the training losses that you want to plot and save.
+    -- self.model_names (str list): define networks used in our training.
+    -- self.visual_names (str list): specify the images that you want to display and save.
+    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+import os
+import importlib
+import numpy as np
+import torch
+import torch.nn as nn
+from talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+    """Import the module "models/[model_name]_model.py".
+
+    In the file, the class called DatasetNameModel() will
+    be instantiated. It has to be a subclass of BaseModel,
+    and it is case-insensitive.
+    """
+    model_filename = "talkingface.model.audio_driven_talkingface.LiveSpeechPortraits."
+ model_name + "_model"
+    modellib = importlib.import_module(model_filename)
+    model = None
+    target_model_name = model_name.replace('_', '') + 'model'
+    for name, cls in modellib.__dict__.items():
+        if name.lower() == target_model_name.lower() \
+                and issubclass(cls, BaseModel):
+            model = cls
+
+    if model is None:
+        print("In %s.py, there should be a subclass of BaseModel with a class name that matches %s in lowercase." % (model_filename, target_model_name))
+        exit()
+
+    return model
+
+
+def get_option_setter(model_name):
+    """Return the static method <modify_commandline_options> of the model class."""
+    model_class = find_model_using_name(model_name)
+    return model_class.modify_commandline_options
+
+
+def create_model(opt):
+    """Create a model given the option.
+
+    This function wraps the class CustomDatasetDataLoader.
+    This is the main interface between this package and 'train.py'/'test.py'.
+
+    Example:
+# >>> from models import create_model
+# >>> model = create_model(opt)
+    """
+    model = find_model_using_name(opt.model_name)
+    instance = model(opt)
+    print("model [%s] was created" % type(instance).__name__)
+    return instance
+
+
+def save_models(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, modelG, modelD, end_of_epoch=False):
+    if not end_of_epoch:
+        if total_steps % opt.save_latest_freq == 0:
+            visualizer.vis_print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
+            modelG.module.save('latest')
+            modelD.module.save('latest')
+            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
+    else:
+        if epoch % opt.save_epoch_freq == 0:
+            visualizer.vis_print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
+            modelG.module.save('latest')
+            modelD.module.save('latest')
+            modelG.module.save(epoch)
+            modelD.module.save(epoch)
+            np.savetxt(iter_path, (epoch + 1, 0), delimiter=',', fmt='%d')
+
+
+def update_models(opt, epoch, modelG, modelD, dataset_warp):
+    ### linearly decay the learning rate after a certain number of iterations
+    if
epoch > opt.niter: + modelG.module.update_learning_rate(epoch, 'G') + modelD.module.update_learning_rate(epoch, 'D') + + ### gradually grow training sequence length + if (epoch % opt.niter_step) == 0: + dataset_warp.dataset.update_training_batch(epoch//opt.niter_step) +# modelG.module.update_training_batch(epoch//opt.niter_step) + + ### finetune all scales + if (opt.n_scales_spatial > 1) and (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): + modelG.module.update_fixed_params() + + +class myModel(nn.Module): + def __init__(self, opt, model): + super(myModel, self).__init__() + self.opt = opt + self.module = model + self.model = nn.DataParallel(model, device_ids=opt.gpu_ids) + self.bs_per_gpu = int(np.ceil(float(opt.batch_size) / len(opt.gpu_ids))) # batch size for each GPU + self.pad_bs = self.bs_per_gpu * len(opt.gpu_ids) - opt.batch_size + + def forward(self, *inputs, **kwargs): + inputs = self.add_dummy_to_tensor(inputs, self.pad_bs) + outputs = self.model(*inputs, **kwargs, dummy_bs=self.pad_bs) + if self.pad_bs == self.bs_per_gpu: # gpu 0 does 0 batch but still returns 1 batch + return self.remove_dummy_from_tensor(outputs, 1) + return outputs + + def add_dummy_to_tensor(self, tensors, add_size=0): + if add_size == 0 or tensors is None: return tensors + if type(tensors) == list or type(tensors) == tuple: + return [self.add_dummy_to_tensor(tensor, add_size) for tensor in tensors] + + if isinstance(tensors, torch.Tensor): + dummy = torch.zeros_like(tensors)[:add_size] + tensors = torch.cat([dummy, tensors]) + return tensors + + def remove_dummy_from_tensor(self, tensors, remove_size=0): + if remove_size == 0 or tensors is None: return tensors + if type(tensors) == list or type(tensors) == tuple: + return [self.remove_dummy_from_tensor(tensor, remove_size) for tensor in tensors] + + if isinstance(tensors, torch.Tensor): + tensors = tensors[remove_size:] + return tensors + + diff --git 
a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2feature.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2feature.py new file mode 100644 index 00000000..91f81755 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2feature.py @@ -0,0 +1,84 @@ +import torch.nn as nn +from .networks import WaveNet + + + +class Audio2Feature(nn.Module): + def __init__(self, opt): + super(Audio2Feature, self).__init__() + self.opt = opt + opt.A2L_wavenet_input_channels = opt.APC_hidden_size + if self.opt.loss == 'GMM': + output_size = (2 * opt.A2L_GMM_ndim + 1) * opt.A2L_GMM_ncenter + elif self.opt.loss == 'L2': + num_pred = opt.predict_length + output_size = opt.A2L_GMM_ndim * num_pred + # define networks + if opt.feature_decoder == 'WaveNet': + self.WaveNet = WaveNet(opt.A2L_wavenet_residual_layers, + opt.A2L_wavenet_residual_blocks, + opt.A2L_wavenet_residual_channels, + opt.A2L_wavenet_dilation_channels, + opt.A2L_wavenet_skip_channels, + opt.A2L_wavenet_kernel_size, + opt.time_frame_length, + opt.A2L_wavenet_use_bias, + opt.A2L_wavenet_cond, + opt.A2L_wavenet_input_channels, + opt.A2L_GMM_ncenter, + opt.A2L_GMM_ndim, + output_size) + self.item_length = self.WaveNet.receptive_field + opt.time_frame_length - 1 + elif opt.feature_decoder == 'LSTM': + self.downsample = nn.Sequential( + nn.Linear(in_features=opt.APC_hidden_size * 2, out_features=opt.APC_hidden_size), + nn.BatchNorm1d(opt.APC_hidden_size), + nn.LeakyReLU(0.2), + nn.Linear(opt.APC_hidden_size, opt.APC_hidden_size), + ) + self.LSTM = nn.LSTM(input_size=opt.APC_hidden_size, + hidden_size=256, + num_layers=3, + dropout=0, + bidirectional=False, + batch_first=True) + self.fc = nn.Sequential( + nn.Linear(in_features=256, out_features=512), + nn.BatchNorm1d(512), + nn.LeakyReLU(0.2), + nn.Linear(512, 512), + nn.BatchNorm1d(512), + nn.LeakyReLU(0.2), + nn.Linear(512, output_size)) + + + def forward(self, audio_features): + ''' + Args: + 
audio_features: [b, T, ndim] + ''' + if self.opt.feature_decoder == 'WaveNet': + pred = self.WaveNet.forward(audio_features.permute(0,2,1)) + elif self.opt.feature_decoder == 'LSTM': + bs, item_len, ndim = audio_features.shape + # new in 0324 + audio_features = audio_features.reshape(bs, -1, ndim*2) + down_audio_feats = self.downsample(audio_features.reshape(-1, ndim*2)).reshape(bs, int(item_len/2), ndim) + output, (hn, cn) = self.LSTM(down_audio_feats) +# output, (hn, cn) = self.LSTM(audio_features) + pred = self.fc(output.reshape(-1, 256)).reshape(bs, int(item_len/2), -1) +# pred = self.fc(output.reshape(-1, 256)).reshape(bs, item_len, -1)[:, -self.opt.time_frame_length:, :] + + return pred + + + + + + + + + + + + diff --git a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2feature_model.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2feature_model.py new file mode 100644 index 00000000..d029fdf3 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2feature_model.py @@ -0,0 +1,157 @@ +import numpy as np +import torch + +from .base_model import BaseModel +from . import networks +from . import audio2feature + + + + + +class Audio2FeatureModel(BaseModel): + def __init__(self, opt): + """Initialize the Audio2Feature class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + + # specify the models you want to save to the disk. 
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
+        # define networks
+        self.model_names = ['Audio2Feature']
+        self.Audio2Feature = networks.init_net(audio2feature.Audio2Feature(opt), init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)
+
+        # defined only during training time
+        if self.isTrain:
+            # losses
+            self.featureL2loss = torch.nn.MSELoss().to(self.device)
+            # optimizer
+            self.optimizer = torch.optim.Adam([{'params': self.Audio2Feature.parameters(),
+                                                'initial_lr': opt.lr}], lr=opt.lr, betas=(0.9, 0.99))
+
+            self.optimizers.append(self.optimizer)
+
+        if opt.continue_train:
+            self.resume_training()
+
+    def resume_training(self):
+        opt = self.opt
+        ### if continuing training, recover the previous states
+        print('Resuming from epoch %s ' % (opt.load_epoch))
+        # change the epoch count & update the schedule settings
+        opt.epoch_count = int(opt.load_epoch)
+        self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+        # print the learning rate
+        lr = self.optimizers[0].param_groups[0]['lr']
+        print('update learning rate: {} -> {}'.format(opt.lr, lr))
+
+    def set_input(self, data, data_info=None):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+ """ + + self.audio_feats, self.target_info = data +# b, item_length, mel_channels, width = self.audio_feats.shape + self.audio_feats = self.audio_feats.to(self.device) + self.target_info = self.target_info.to(self.device) + + # gaussian noise +# if self.opt.gaussian_noise: +# self.audio_feats = self.opt.gaussian_noise_scale * torch.randn(self.audio_feats.shape).cuda() +# self.target_info += self.opt.gaussian_noise_scale * torch.randn(self.target_info.shape).cuda() + + + + def forward(self): + ''' + Args: + history_landmarks: [b, T, ndim] + audio_features: [b, 1, nfeas, nwins] + Returns: + preds: [b, T, output_channels] + ''' + self.preds = self.Audio2Feature.forward(self.audio_feats) + + + + def calculate_loss(self): + """ calculate loss in detail, only forward pass included""" + if self.opt.loss == 'GMM': + b, T, _ = self.target_info.shape + self.loss_GMM = self.criterion_GMM(self.preds, self.target_info) + self.loss = self.loss_GMM + + elif self.opt.loss == 'L2': + frame_future = self.opt.frame_future + if not frame_future == 0: + self.loss = self.featureL2loss(self.preds[:, frame_future:], self.target_info[:, :-frame_future]) * 1000 + else: + self.loss = self.featureL2loss(self.preds, self.target_info) * 1000 + + + def backward(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + self.calculate_loss() + self.loss.backward() + + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.optimizer.zero_grad() # clear optimizer parameters grad + self.forward() # forward pass + self.backward() # calculate loss and gradients + self.optimizer.step() # update gradients + + + def validate(self): + """ validate process """ + with torch.no_grad(): + self.forward() + self.calculate_loss() + + + def generate_sequences(self, audio_feats, sample_rate = 16000, fps=60, fill_zero=True, opt=[]): + ''' generate landmark sequences given audio and a initialized 
landmark.
+ Note that the audio input should have the same sample rate as the training data.
+ Args:
+ audio_sequences: [n,], in numpy
+ init_landmarks: [npts, 2], in numpy
+ sample_rate: audio sample rate, should be the same as in the training process.
+ method(str): optional, how to generate the sequence; this matches the
+ loss function used during training. Options are 'L2' or 'GMM'.
+ Returns:
+ landmark_sequences: [T, npts, 2] predicted landmark sequences
+ '''
+
+ frame_future = opt.frame_future
+ nframe = int(audio_feats.shape[0] / 2)
+
+ if not frame_future == 0:
+ audio_feats_insert = np.repeat(audio_feats[-1], 2 * (frame_future)).reshape(-1, 2 * (frame_future)).T
+ audio_feats = np.concatenate([audio_feats, audio_feats_insert])
+
+
+ # evaluate mode
+ self.Audio2Feature.eval()
+
+ with torch.no_grad():
+ input = torch.from_numpy(audio_feats).unsqueeze(0).float().to(self.device)
+ preds = self.Audio2Feature.forward(input)
+
+ # drop first frame future results
+ if not frame_future == 0:
+ preds = preds[0, frame_future:].cpu().detach().numpy()
+ else:
+ preds = preds[0, :].cpu().detach().numpy()
+
+ assert preds.shape[0] == nframe
+
+
+ return preds
+
+
+ \ No newline at end of file diff --git a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2headpose.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2headpose.py new file mode 100644 index 00000000..b6afe163 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2headpose.py @@ -0,0 +1,108 @@ +import torch.nn as nn
+
+from .networks import WaveNet
+
+
+
+class Audio2Headpose(nn.Module):
+ def __init__(self, opt):
+ super(Audio2Headpose, self).__init__()
+ self.opt = opt
+ if self.opt.loss == 'GMM':
+ output_size = (2 * opt.A2H_GMM_ndim + 1) * opt.A2H_GMM_ncenter
+ elif self.opt.loss == 'L2':
+ output_size = opt.A2H_GMM_ndim
+ # define networks
+ self.audio_downsample = nn.Sequential(
+ nn.Linear(in_features=opt.APC_hidden_size * 2,
out_features=opt.APC_hidden_size), + nn.BatchNorm1d(opt.APC_hidden_size), + nn.LeakyReLU(0.2), + nn.Linear(opt.APC_hidden_size, opt.APC_hidden_size), + ) + + self.WaveNet = WaveNet(opt.A2H_wavenet_residual_layers, + opt.A2H_wavenet_residual_blocks, + opt.A2H_wavenet_residual_channels, + opt.A2H_wavenet_dilation_channels, + opt.A2H_wavenet_skip_channels, + opt.A2H_wavenet_kernel_size, + opt.time_frame_length, + opt.A2H_wavenet_use_bias, + True, + opt.A2H_wavenet_input_channels, + opt.A2H_GMM_ncenter, + opt.A2H_GMM_ndim, + output_size, + opt.A2H_wavenet_cond_channels) + self.item_length = self.WaveNet.receptive_field + opt.time_frame_length - 1 + + + def forward(self, history_info, audio_features): + ''' + Args: + history_info: [b, T, ndim] + audio_features: [b, 1, nfeas, nwins] + ''' + # APC features: [b, item_length, APC_hidden_size] ==> [b, APC_hidden_size, item_length] + bs, item_len, ndim = audio_features.shape + down_audio_feats = self.audio_downsample(audio_features.reshape(-1, ndim)).reshape(bs, item_len, -1) + pred = self.WaveNet.forward(history_info.permute(0,2,1), down_audio_feats.transpose(1,2)) + + + return pred + + + + +class Audio2Headpose_LSTM(nn.Module): + def __init__(self, opt): + super(Audio2Headpose_LSTM, self).__init__() + self.opt = opt + if self.opt.loss == 'GMM': + output_size = (2 * opt.A2H_GMM_ndim + 1) * opt.A2H_GMM_ncenter + elif self.opt.loss == 'L2': + output_size = opt.A2H_GMM_ndim + # define networks + self.audio_downsample = nn.Sequential( + nn.Linear(in_features=opt.APC_hidden_size * 2, out_features=opt.APC_hidden_size), + nn.BatchNorm1d(opt.APC_hidden_size), + nn.LeakyReLU(0.2), + nn.Linear(opt.APC_hidden_size, opt.APC_hidden_size), + ) + + self.LSTM = nn.LSTM(input_size=opt.APC_hidden_size, + hidden_size=256, + num_layers=3, + dropout=0, + bidirectional=False, + batch_first=True) + self.fc = nn.Sequential( + nn.Linear(in_features=256, out_features=512), + nn.BatchNorm1d(512), + nn.LeakyReLU(0.2), + nn.Linear(512, 512), + 
nn.BatchNorm1d(512), + nn.LeakyReLU(0.2), + nn.Linear(512, output_size)) + + + def forward(self, audio_features): + ''' + Args: + history_info: [b, T, ndim] + audio_features: [b, 1, nfeas, nwins] + ''' + # APC features: [b, item_length, APC_hidden_size] ==> [b, APC_hidden_size, item_length] + bs, item_len, ndim = audio_features.shape + down_audio_feats = self.audio_downsample(audio_features.reshape(-1, ndim)).reshape(bs, item_len, -1) + output, (hn, cn) = self.LSTM(down_audio_feats) + pred = self.fc(output.reshape(-1, 256)).reshape(bs, item_len, -1) + + + return pred + + + + + + \ No newline at end of file diff --git a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2headpose_model.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2headpose_model.py new file mode 100644 index 00000000..5b713891 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/audio2headpose_model.py @@ -0,0 +1,208 @@ +import numpy as np +import torch +from tqdm import tqdm + +from .base_model import BaseModel +from . import networks +from . import audio2headpose +from .losses import GMMLogLoss, Sample_GMM +import torch.nn as nn + + + +class Audio2HeadposeModel(BaseModel): + def __init__(self, opt): + """Initialize the Audio2Headpose class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + + # specify the models you want to save to the disk. 
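Both decoder variants above size their GMM output head the same way: for each of the `ncenter` mixture components the network emits `ndim` means, `ndim` sigmas, and one mixture weight, giving `(2 * ndim + 1) * ncenter` outputs. A quick sanity check (the parameter values here are illustrative, not the project defaults):

```python
def gmm_output_size(ndim, ncenter):
    # ndim means + ndim sigmas + 1 mixture weight per Gaussian component
    return (2 * ndim + 1) * ncenter

# e.g. a 12-dim target with a single mixture component
print(gmm_output_size(12, 1))  # 25
```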
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
+ # define networks
+ self.model_names = ['Audio2Headpose']
+ if opt.feature_decoder == 'WaveNet':
+ self.Audio2Headpose = networks.init_net(audio2headpose.Audio2Headpose(opt), init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)
+ elif opt.feature_decoder == 'LSTM':
+ self.Audio2Headpose = networks.init_net(audio2headpose.Audio2Headpose_LSTM(opt), init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)
+
+ # define only during training time
+ if self.isTrain:
+ # losses
+ self.criterion_GMM = GMMLogLoss(opt.A2H_GMM_ncenter, opt.A2H_GMM_ndim, opt.A2H_GMM_sigma_min).to(self.device)
+ self.criterion_L2 = nn.MSELoss().cuda()
+ # optimizer
+ self.optimizer = torch.optim.Adam([{'params':self.Audio2Headpose.parameters(),
+ 'initial_lr': opt.lr}], lr=opt.lr, betas=(0.9, 0.99))
+
+ self.optimizers.append(self.optimizer)
+
+ if opt.continue_train:
+ self.resume_training()
+
+
+ def resume_training(self):
+ opt = self.opt
+ ### if continue training, recover previous states
+ print('Resuming from epoch %s ' % (opt.load_epoch))
+ # change epoch count & update schedule settings
+ opt.epoch_count = int(opt.load_epoch)
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+ # print learning rate
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('update learning rate: {} -> {}'.format(opt.lr, lr))
+
+
+
+ def set_input(self, data, data_info=None):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+ """ + if self.opt.feature_decoder == 'WaveNet': + self.headpose_audio_feats, self.history_headpose, self.target_headpose = data + self.headpose_audio_feats = self.headpose_audio_feats.to(self.device) + self.history_headpose = self.history_headpose.to(self.device) + self.target_headpose = self.target_headpose.to(self.device) + elif self.opt.feature_decoder == 'LSTM': + self.headpose_audio_feats, self.target_headpose = data + self.headpose_audio_feats = self.headpose_audio_feats.to(self.device) + self.target_headpose = self.target_headpose.to(self.device) + + + + def forward(self): + ''' + Args: + history_landmarks: [b, T, ndim] + audio_features: [b, 1, nfeas, nwins] + Returns: + preds: [b, T, output_channels] + ''' + + if self.opt.audio_windows == 2: + bs, item_len, ndim = self.headpose_audio_feats.shape + self.headpose_audio_feats = self.headpose_audio_feats.reshape(bs, -1, ndim * 2) + else: + bs, item_len, _, ndim = self.headpose_audio_feats.shape + if self.opt.feature_decoder == 'WaveNet': + self.preds_headpose = self.Audio2Headpose.forward(self.history_headpose, self.headpose_audio_feats) + elif self.opt.feature_decoder == 'LSTM': + self.preds_headpose = self.Audio2Headpose.forward(self.headpose_audio_feats) + + + def calculate_loss(self): + """ calculate loss in detail, only forward pass included""" + if self.opt.loss == 'GMM': + self.loss_GMM = self.criterion_GMM(self.preds_headpose, self.target_headpose) + self.loss = self.loss_GMM + elif self.opt.loss == 'L2': + self.loss_L2 = self.criterion_L2(self.preds_headpose, self.target_headpose) + self.loss = self.loss_L2 + + if not self.opt.smooth_loss == 0: + mu_gen = Sample_GMM(self.preds_headpose, + self.Audio2Headpose.module.WaveNet.ncenter, + self.Audio2Headpose.module.WaveNet.ndim, + sigma_scale=0) + self.smooth_loss = (mu_gen[:,2:] + self.target_headpose[:,:-2] - 2 * self.target_headpose[:,1:-1]).mean(dim=2).abs().mean() + self.loss += self.smooth_loss * self.opt.smooth_loss + + + + def backward(self): + 
"""Calculate losses, gradients, and update network weights; called in every training iteration""" + self.calculate_loss() + self.loss.backward() + + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.optimizer.zero_grad() # clear optimizer parameters grad + self.forward() # forward pass + self.backward() # calculate loss and gradients + self.optimizer.step() # update gradients + + + def validate(self): + """ validate process """ + with torch.no_grad(): + self.forward() + self.calculate_loss() + + + def generate_sequences(self, audio_feats, pre_headpose, fill_zero=True, sigma_scale=0.0, opt=[]): + ''' generate landmark sequences given audio and a initialized landmark. + Note that the audio input should have the same sample rate as the training. + Args: + audio_sequences: [n,], in numpy + init_landmarks: [npts, 2], in numpy + sample_rate: audio sample rate, should be same as training process. + method(str): optional, how to generate the sequence, indeed it is the + loss function during training process. Options are 'L2' or 'GMM'. 
+ Returns:
+ landmark_sequences: [T, npts, 2] predicted landmark sequences
+ '''
+
+ frame_future = opt.frame_future
+ audio_feats = audio_feats.reshape(-1, 512 * 2)
+ nframe = audio_feats.shape[0] - frame_future
+ pred_headpose = np.zeros([nframe, opt.A2H_GMM_ndim])
+
+ if opt.feature_decoder == 'WaveNet':
+ # fill zero or not
+ if fill_zero:
+ # headpose
+ audio_feats_insert = np.repeat(audio_feats[0], opt.A2H_receptive_field - 1)
+ audio_feats_insert = audio_feats_insert.reshape(-1, opt.A2H_receptive_field - 1).T
+ audio_feats = np.concatenate([audio_feats_insert, audio_feats])
+ # history headpose
+ history_headpose = np.repeat(pre_headpose, opt.A2H_receptive_field)
+ history_headpose = history_headpose.reshape(-1, opt.A2H_receptive_field).T
+ history_headpose = torch.from_numpy(history_headpose).unsqueeze(0).float().to(self.device)
+ infer_start = 0
+ else:
+ return None
+
+ # evaluate mode
+ self.Audio2Headpose.eval()
+
+ with torch.no_grad():
+ for i in tqdm(range(infer_start, nframe), desc='generating headpose'):
+ history_start = i - infer_start
+ input_audio_feats = audio_feats[history_start + frame_future: history_start + frame_future + opt.A2H_receptive_field]
+ input_audio_feats = torch.from_numpy(input_audio_feats).unsqueeze(0).float().to(self.device)
+
+ if self.opt.feature_decoder == 'WaveNet':
+ preds = self.Audio2Headpose.forward(history_headpose, input_audio_feats)
+ elif self.opt.feature_decoder == 'LSTM':
+ preds = self.Audio2Headpose.forward(input_audio_feats)
+
+ if opt.loss == 'GMM':
+ pred_data = Sample_GMM(preds, opt.A2H_GMM_ncenter, opt.A2H_GMM_ndim, sigma_scale=sigma_scale)
+ elif opt.loss == 'L2':
+ pred_data = preds
+
+ # get predictions
+ pred_headpose[i] = pred_data[0,0].cpu().detach().numpy()
+ history_headpose = torch.cat((history_headpose[:,1:,:], pred_data.to(self.device)), dim=1) # add in time-axis
+
+ return pred_headpose
+
+ elif opt.feature_decoder == 'LSTM':
+ self.Audio2Headpose.eval()
+ with torch.no_grad():
input = torch.from_numpy(audio_feats).unsqueeze(0).float().to(self.device)
+ preds = self.Audio2Headpose.forward(input)
+ if opt.loss == 'GMM':
+ pred_data = Sample_GMM(preds, opt.A2H_GMM_ncenter, opt.A2H_GMM_ndim, sigma_scale=sigma_scale)
+ elif opt.loss == 'L2':
+ pred_data = preds
+ # get predictions
+ pred_headpose = pred_data[0].cpu().detach().numpy()
+
+ return pred_headpose
+
+
+
+
+ \ No newline at end of file diff --git a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/base_model.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/base_model.py new file mode 100644 index 00000000..b654005b --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/base_model.py @@ -0,0 +1,272 @@ +import os
+import torch
+import numpy as np
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+
+class BaseModel(ABC):
+ """This class is an abstract base class (ABC) for models.
+ To create a subclass, you need to implement the following five functions:
+ -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+ -- <set_input>: unpack data from dataset and apply preprocessing.
+ -- <forward>: produce intermediate results.
+ -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+ -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+ """
+
+ def __init__(self, opt):
+ """Initialize the BaseModel class.
+
+ Parameters:
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+ When creating your custom class, you need to implement your own initialization.
+ In this function, you should first call <BaseModel.__init__(self, opt)>.
+ Then, you need to define four lists:
+ -- self.loss_names (str list): specify the training losses that you want to plot and save.
+ -- self.model_names (str list): define networks used in our training.
+ -- self.visual_names (str list): specify the images that you want to display and save.
+ -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + # get device name: CPU or GPU + # if self.gpu_ids == '-1': + # self.device = torch.device('cpu') + # self.gpu_ids = opt.gpu_ids == [] + # else: + # self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if len(self.gpu_ids) > 0 else torch.device('cpu') + + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + # torch speed up training + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. 
+ """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + if not self.isTrain or opt.continue_train: + self.load_networks(opt.load_epoch) + self.print_networks(opt.verbose) + + + def train(self): + """Make models train mode during train time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.train(mode=True) + + + def eval(self): + """Make models eval mode during test time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.eval() + + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self): + """ Return image paths that are used to load current data""" + return self.image_paths + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + def get_current_visuals(self): + """Return visualization images. 
train.py will display these images with visdom, and save the images to an HTML"""
+ visual_ret = OrderedDict()
+ for name in self.visual_names:
+ if isinstance(name, str):
+ visual_ret[name] = getattr(self, name)
+ return visual_ret
+
+ def get_current_losses(self):
+ """Return training losses / errors. train.py will print out these errors on console, and save them to a file"""
+ errors_ret = OrderedDict()
+ for name in self.loss_names:
+ if isinstance(name, str):
+ errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
+ return errors_ret
+
+ def save_networks(self, epoch, train_info=None):
+ """Save all the networks to the disk.
+
+ Parameters:
+ epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+ """
+ for name in self.model_names:
+ if isinstance(name, str):
+ save_filename = '%s_%s.pkl' % (epoch, name)
+ save_path = os.path.join(self.save_dir, save_filename)
+ net = getattr(self, name)
+ torch.save(net.state_dict(), save_path)
+ if train_info is not None:
+ epoch, epoch_iter = train_info
+ iter_path = os.path.join(self.save_dir, 'iter.txt')
+ np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
+
+
+ def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+ """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+ key = keys[i]
+ if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'running_mean' or key == 'running_var'):
+ if getattr(module, key) is None:
+ state_dict.pop('.'.join(keys))
+ if module.__class__.__name__.startswith('InstanceNorm') and \
+ (key == 'num_batches_tracked'):
+ state_dict.pop('.'.join(keys))
+ else:
+ self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+ def load_networks(self, epoch):
+ """Load all the networks from the disk.
+ + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + + + for name in self.model_names: + if isinstance(name, str): + if epoch[-3:] == 'pkl': + load_path = epoch + else: + load_filename = '%s_%s.pkl' % (epoch, name) + load_path = os.path.join(self.save_dir, load_filename) + net = getattr(self, name) +# if isinstance(net, torch.nn.DataParallel): +# net = net.module + if os.path.exists(load_path): + state_dict = torch.load(load_path, map_location=str(self.device)) + if self.device == torch.device('cpu'): + for key in list(state_dict.keys()): + state_dict[key[7:]] = state_dict.pop(key) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + print('loading the model from %s' % load_path) + net.load_state_dict(state_dict, strict=False) + else: + print('No model weight file:', load_path, 'initialize model without pre-trained weights.') + if self.isTrain == False: + raise ValueError('We are now in inference process, no pre-trained model found! 
Check the model checkpoint!')
+
+
+# if isinstance(net, torch.nn.DataParallel):
+# net = net.module
+
+ # if you are using PyTorch newer than 0.4 (e.g., built from
+ # GitHub source), you can remove str() on self.device
+
+# state_dict = torch.load(load_path, map_location=str(self.device))
+# if hasattr(state_dict, '_metadata'):
+# del state_dict._metadata
+#
+# # patch InstanceNorm checkpoints prior to 0.4
+# for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
+# self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+# net.load_state_dict(state_dict)
+
+
+
+ def print_networks(self, verbose):
+ """Print the total number of parameters in the network and (if verbose) network architecture
+
+ Parameters:
+ verbose (bool) -- if verbose: print the network architecture
+ """
+ print('---------- Networks initialized -------------')
+ for name in self.model_names:
+ if isinstance(name, str):
+ net = getattr(self, name)
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ if verbose:
+ print(net)
+ print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+ print('-----------------------------------------------')
+
+ def set_requires_grad(self, nets, requires_grad=False):
+ """Set requires_grad=False for all the networks to avoid unnecessary computations
+ Parameters:
+ nets (network list) -- a list of networks
+ requires_grad (bool) -- whether the networks require gradients or not
+ """
+ if not isinstance(nets, list):
+ nets = [nets]
+ for net in nets:
+ if net is not None:
+ for param in net.parameters():
+ param.requires_grad = requires_grad diff --git a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/feature2face_D.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/feature2face_D.py new file mode 100644 index 00000000..73b7ff65 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/feature2face_D.py @@
-0,0 +1,52 @@ +import torch +import torch.nn as nn + + +from .networks import MultiscaleDiscriminator +from torch.cuda.amp import autocast as autocast + + + +class Feature2Face_D(nn.Module): + def __init__(self, opt): + super(Feature2Face_D, self).__init__() + # initialize + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor + self.tD = opt.n_frames_D + self.output_nc = opt.output_nc + + # define networks + self.netD = MultiscaleDiscriminator(23 + 3, opt.ndf, opt.n_layers_D, opt.num_D, not opt.no_ganFeat) + + print('---------- Discriminator networks initialized -------------') + print('-----------------------------------------------------------') + + #@autocast() + def forward(self, input): + if self.opt.fp16: + with autocast(): + pred = self.netD(input) + else: + pred = self.netD(input) + + return pred + + + + + + + + + + + + + + + + + diff --git a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/feature2face_G.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/feature2face_G.py new file mode 100644 index 00000000..37234dff --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/feature2face_G.py @@ -0,0 +1,46 @@ +import torch.nn as nn + +from .networks import Feature2FaceGenerator_Unet, Feature2FaceGenerator_normal, Feature2FaceGenerator_large + +from torch.cuda.amp import autocast as autocast + + +class Feature2Face_G(nn.Module): + def __init__(self, opt): + super(Feature2Face_G, self).__init__() + # initialize + self.opt = opt + self.isTrain = opt.isTrain + # define net G + + if opt.size == 'small': + self.netG = Feature2FaceGenerator_Unet(input_nc=23, output_nc=3, num_downs=opt.n_downsample_G, ngf=opt.ngf) + elif opt.size == 'normal': + self.netG = Feature2FaceGenerator_normal(input_nc=13, output_nc=3, num_downs=opt.n_downsample_G, ngf=opt.ngf) + elif opt.size == 'large': + self.netG = Feature2FaceGenerator_large(input_nc=13, output_nc=3, 
num_downs=opt.n_downsample_G, ngf=opt.ngf) + + print('---------- Generator networks initialized -------------') + print('-------------------------------------------------------') + + + def forward(self, input): + if self.opt.fp16: + with autocast(): + fake_pred = self.netG(input) + else: + fake_pred = self.netG(input) + + return fake_pred + + + + + + + + + + + + diff --git a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/feature2face_model.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/feature2face_model.py new file mode 100644 index 00000000..b2fadb65 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/feature2face_model.py @@ -0,0 +1,246 @@ +import os +import os.path +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import autocast as autocast + +from . import networks +from . import feature2face_G +from .base_model import BaseModel +from .losses import GANLoss, MaskedL1Loss, VGGLoss + + + +class Feature2FaceModel(BaseModel): + def __init__(self, opt): + """Initialize the Feature2Face class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + self.Tensor = torch.cuda.FloatTensor + # specify the models you want to save to the disk. The training/test scripts will call and + # define networks + self.model_names = ['Feature2Face_G'] + self.Feature2Face_G = networks.init_net(feature2face_G.Feature2Face_G(opt), init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids) + if self.isTrain: + if not opt.no_discriminator: + self.model_names += ['Feature2Face_D'] + from . 
import feature2face_D
+ self.Feature2Face_D = networks.init_net(feature2face_D.Feature2Face_D(opt), init_type='normal', init_gain=0.02, gpu_ids=opt.gpu_ids)
+
+
+ # define only during training time
+ if self.isTrain:
+ # define losses names
+ self.loss_names_G = ['L1', 'VGG', 'Style', 'loss_G_GAN', 'loss_G_FM']
+ # criterion
+ self.criterionMaskL1 = MaskedL1Loss().cuda()
+ self.criterionL1 = nn.L1Loss().cuda()
+ self.criterionVGG = VGGLoss().cuda() # instantiate the loss module before moving it to GPU
+ self.criterionFlow = nn.L1Loss().cuda()
+
+ # initialize optimizer G
+ if opt.TTUR:
+ beta1, beta2 = 0, 0.9
+ lr = opt.lr / 2
+ else:
+ beta1, beta2 = opt.beta1, 0.999
+ lr = opt.lr
+ self.optimizer_G = torch.optim.Adam([{'params': self.Feature2Face_G.module.parameters(),
+ 'initial_lr': lr}],
+ lr=lr,
+ betas=(beta1, beta2))
+ self.optimizers.append(self.optimizer_G)
+
+ # fp16 training
+ if opt.fp16:
+ self.scaler = torch.cuda.amp.GradScaler()
+
+ # discriminator setting
+ if not opt.no_discriminator:
+ self.criterionGAN = GANLoss(opt.gan_mode, tensor=self.Tensor)
+ self.loss_names_D = ['D_real', 'D_fake']
+ # initialize optimizer D
+ if opt.TTUR:
+ beta1, beta2 = 0, 0.9
+ lr = opt.lr * 2
+ else:
+ beta1, beta2 = opt.beta1, 0.999
+ lr = opt.lr
+ self.optimizer_D = torch.optim.Adam([{'params': self.Feature2Face_D.module.netD.parameters(),
+ 'initial_lr': lr}],
+ lr=lr,
+ betas=(beta1, beta2))
+ self.optimizers.append(self.optimizer_D)
+
+
+ def init_paras(self, dataset):
+ opt = self.opt
+ iter_path = os.path.join(self.save_dir, 'iter.txt')
+ start_epoch, epoch_iter = 1, 0
+ ### if continue training, recover previous states
+ if opt.continue_train:
+ if os.path.exists(iter_path):
+ start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
+ print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
+ # change epoch count & update schedule settings
+ opt.epoch_count = start_epoch
+ self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+ # print
learning rate
+ lr = self.optimizers[0].param_groups[0]['lr']
+ print('update learning rate: {} -> {}'.format(opt.lr, lr))
+ else:
+ print('training log not found, training from epoch 1')
+ # change training sequence length
+# if start_epoch > opt.nepochs_step:
+# dataset.dataset.update_training_batch((start_epoch-1)//opt.nepochs_step)
+
+
+ total_steps = (start_epoch-1) * len(dataset) + epoch_iter
+ total_steps = total_steps // opt.print_freq * opt.print_freq
+
+ return start_epoch, opt.print_freq, total_steps, epoch_iter
+
+
+
+ def set_input(self, data, data_info=None):
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
+ """
+ self.feature_map, self.cand_image, self.tgt_image, self.facial_mask = \
+ data['feature_map'], data['cand_image'], data['tgt_image'], data['weight_mask']
+ self.feature_map = self.feature_map.to(self.device)
+ self.cand_image = self.cand_image.to(self.device)
+ self.tgt_image = self.tgt_image.to(self.device)
+# self.facial_mask = self.facial_mask.to(self.device)
+
+
+ def forward(self):
+ ''' forward pass for feature2Face
+ '''
+ self.input_feature_maps = torch.cat([self.feature_map, self.cand_image], dim=1)
+ self.fake_pred = self.Feature2Face_G(self.input_feature_maps)
+
+
+
+
+ def backward_G(self):
+ """Calculate GAN and other loss for the generator"""
+ # GAN loss
+ real_AB = torch.cat((self.input_feature_maps, self.tgt_image), dim=1)
+ fake_AB = torch.cat((self.input_feature_maps, self.fake_pred), dim=1)
+ pred_real = self.Feature2Face_D(real_AB)
+ pred_fake = self.Feature2Face_D(fake_AB)
+ loss_G_GAN = self.criterionGAN(pred_fake, True)
+ # L1, vgg, style loss
+ loss_l1 = self.criterionL1(self.fake_pred, self.tgt_image) * self.opt.lambda_L1
+# loss_maskL1 = self.criterionMaskL1(self.fake_pred, self.tgt_image, self.facial_mask * self.opt.lambda_mask)
+ loss_vgg, loss_style = self.criterionVGG(self.fake_pred, self.tgt_image, style=True)
+ loss_vgg = torch.mean(loss_vgg) * self.opt.lambda_feat
+ loss_style = torch.mean(loss_style) * self.opt.lambda_feat + # feature matching loss + loss_FM = self.compute_FeatureMatching_loss(pred_fake, pred_real) + + # combine loss and calculate gradients + + if not self.opt.fp16: + self.loss_G = loss_G_GAN + loss_l1 + loss_vgg + loss_style + loss_FM #+ loss_maskL1 + self.loss_G.backward() + else: + with autocast(): + self.loss_G = loss_G_GAN + loss_l1 + loss_vgg + loss_style + loss_FM #+ loss_maskL1 + self.scaler.scale(self.loss_G).backward() + + self.loss_dict = {**self.loss_dict, **dict(zip(self.loss_names_G, [loss_l1, loss_vgg, loss_style, loss_G_GAN, loss_FM]))} + + + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # GAN loss + real_AB = torch.cat((self.input_feature_maps, self.tgt_image), dim=1) + fake_AB = torch.cat((self.input_feature_maps, self.fake_pred), dim=1) + pred_real = self.Feature2Face_D(real_AB) + pred_fake = self.Feature2Face_D(fake_AB.detach()) + with autocast(): + loss_D_real = self.criterionGAN(pred_real, True) * 2 + loss_D_fake = self.criterionGAN(pred_fake, False) + + self.loss_D = (loss_D_fake + loss_D_real) * 0.5 + + self.loss_dict = dict(zip(self.loss_names_D, [loss_D_real, loss_D_fake])) + + if not self.opt.fp16: + self.loss_D.backward() + else: + self.scaler.scale(self.loss_D).backward() + + + def compute_FeatureMatching_loss(self, pred_fake, pred_real): + # GAN feature matching loss + loss_FM = torch.zeros(1).cuda() + feat_weights = 4.0 / (self.opt.n_layers_D + 1) + D_weights = 1.0 / self.opt.num_D + for i in range(min(len(pred_fake), self.opt.num_D)): + for j in range(len(pred_fake[i])): + loss_FM += D_weights * feat_weights * \ + self.criterionL1(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat + + return loss_FM + + + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + # only train single image generation + ## forward + self.forward() + # update D + 
self.set_requires_grad(self.Feature2Face_D, True) # enable backprop for D + self.optimizer_D.zero_grad() # set D's gradients to zero + if not self.opt.fp16: + self.backward_D() # calculate gradients for D + self.optimizer_D.step() # update D's weights + else: + with autocast(): + self.backward_D() + self.scaler.step(self.optimizer_D) + + + # update G + self.set_requires_grad(self.Feature2Face_D, False) # D requires no gradients when optimizing G + self.optimizer_G.zero_grad() # set G's gradients to zero + if not self.opt.fp16: + self.backward_G() # calculate gradients for G + self.optimizer_G.step() # update G's weights + else: + with autocast(): + self.backward_G() + self.scaler.step(self.optimizer_G) + self.scaler.update() + + + def inference(self, feature_map, cand_image): + """ inference process """ + with torch.no_grad(): + if cand_image is None: + input_feature_maps = feature_map + else: + input_feature_maps = torch.cat([feature_map, cand_image], dim=1) + if not self.opt.fp16: + fake_pred = self.Feature2Face_G(input_feature_maps) + else: + with autocast(): + fake_pred = self.Feature2Face_G(input_feature_maps) + return fake_pred + + + + + + + + + diff --git a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/losses.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/losses.py new file mode 100644 index 00000000..0834a72d --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/losses.py @@ -0,0 +1,278 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import math +import torch.nn.functional as F + + +class GMMLogLoss(nn.Module): + ''' Compute the GMM loss between the model output and the ground-truth data. + Args: + ncenter: number of Gaussian components + ndim: dimension of each Gaussian component + sigma_bias: + sigma_min: lower bound applied to the predicted sigma.
+ ''' + def __init__(self, ncenter, ndim, sigma_min=0.03): + super(GMMLogLoss, self).__init__() + self.ncenter = ncenter + self.ndim = ndim + self.sigma_min = sigma_min + + + def forward(self, output, target): + ''' + Args: + output: [b, T, ncenter + ncenter * ndim * 2]: + [:, :, : ncenter] gives each Gaussian's mixture weight + [:, :, ncenter : ncenter + ndim * ncenter] gives the mean of each dimension of each Gaussian + [:, :, ncenter + ndim * ncenter : ncenter + ndim * 2 * ncenter] gives the negative log sigma of each dimension of each Gaussian + target: [b, T, ndim], the ground-truth landmark data + Maximizing the log-likelihood is equivalent to minimizing the negative log-likelihood. + NOTE: computing log(sigma) directly is numerically unstable, e.g. ln(-0.1) is undefined, so the raw sigma output + would have to be clipped positive. Hence we predict the negative log sigma instead to avoid numerical instability, which means: + `` sigma = 1/exp(predict), predict = -ln(sigma) `` + This is exactly the 'B' term below! + Currently only a single Gaussian component is implemented, hence the first ncenter values of the prediction are meaningless. + For a single Gaussian: + L(mu, sigma) = -n/2 * ln(2pi * sigma^2) - 1 / (2 x sigma^2) * sum^n (x_i - mu)^2 (n for prediction times, n=1 for one frame, x_i for gt) + = -1/2 * ln(2pi) - 1/2 * ln(sigma^2) - 1/(2 x sigma^2) * (x - mu)^2 + ==> min -L(mu, sigma) = 0.5 x ln(2pi) + 0.5 x ln(sigma^2) + 1/(2 x sigma^2) * (x - mu)^2 + = 0.5 x ln_2PI + ln(sigma) + 0.5 x (MU_DIFF/sigma)^2 + = A - B + C + The result is averaged over the batch (b), time (T), and remaining dimensions.
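+ Quick sanity check of the formula above (illustrative, not part of the original code):
+ with a single component, sigma = 1 and x = mu, the per-dimension loss reduces to
+ A - B + C = 0.5 * ln(2*pi) - 0 + 0 ~= 0.9189, i.e. the negative log of the
+ standard normal density at its mode.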
''' + b, T, _ = target.shape + # read prediction params + mus = output[:, :, self.ncenter : (self.ncenter + self.ncenter * self.ndim)].view(b, T, self.ncenter, self.ndim) # [b, T, ncenter, ndim] + + # apply min sigma + neg_log_sigmas_out = output[:, :, (self.ncenter + self.ncenter * self.ndim):].view(b, T, self.ncenter, self.ndim) # [b, T, ncenter, ndim] + inv_sigmas_min = torch.ones(neg_log_sigmas_out.size()).cuda() * (1. / self.sigma_min) + inv_sigmas_min_log = torch.log(inv_sigmas_min) + neg_log_sigmas = torch.min(neg_log_sigmas_out, inv_sigmas_min_log) + + inv_sigmas = torch.exp(neg_log_sigmas) + # replicate the target ncenter times to subtract mu + target_rep = target.unsqueeze(2).expand(b, T, self.ncenter, self.ndim) # [b, T, ncenter, ndim] + MU_DIFF = target_rep - mus # [b, T, ncenter, ndim] + # sigma process + A = 0.5 * math.log(2 * math.pi) # 0.9189385332046727 + B = neg_log_sigmas # [b, T, ncenter, ndim] + C = 0.5 * (MU_DIFF * inv_sigmas)**2 # [b, T, ncenter, ndim] + negative_loglikelihood = A - B + C # [b, T, ncenter, ndim] + + return negative_loglikelihood.mean() + + +def Sample_GMM(gmm_params, ncenter, ndim, weight_smooth = 0.0, sigma_scale = 0.0): + ''' Sample values from a given GMM distribution.
+ Args: + gmm_params: [b, target_length, (2 * ndim + 1) * ncenter], including the + distribution weights, means and sigmas + ncenter: number of Gaussian components + ndim: dimension of each Gaussian component + weight_smooth: float, smooth the Gaussian mixture weights + sigma_scale: float, adjust the Gaussian scale, smaller for sharper prediction, + 0 for zero sigma, which always returns the average values + Returns: + current_sample: [b, T, ndim], the sampled values + ''' + # reshape as [b*T, (2 * ndim + 1) * ncenter] + b, T, _ = gmm_params.shape + gmm_params_cpu = gmm_params.cpu().view(-1, (2 * ndim + 1) * ncenter) + # compute each distribution's probability + prob = nn.functional.softmax(gmm_params_cpu[:, : ncenter] * (1 + weight_smooth), dim=1) + # select the Gaussian component according to its weight + selected_idx = torch.multinomial(prob, num_samples=1, replacement=True) + + mu = gmm_params_cpu[:, ncenter : ncenter + ncenter * ndim] + # please note that we use -logsigma as output, hence here we need to take the negative + sigma = torch.exp(-gmm_params_cpu[:, ncenter + ncenter * ndim:]) * sigma_scale +# print('sigma average:', sigma.mean()) + + selected_sigma = torch.empty(b*T, ndim).float() + selected_mu = torch.empty(b*T, ndim).float() + current_sample = torch.randn(b*T, ndim).float() +# current_sample = test_sample + + for i in range(b*T): + idx = selected_idx[i, 0] + selected_sigma[i, :] = sigma[i, idx * ndim:(idx + 1) * ndim] + selected_mu[i, :] = mu[i, idx * ndim:(idx + 1) * ndim] + + # sample with the selected sigma and mean + current_sample = current_sample * selected_sigma + selected_mu + # cur_sample = sel_mu +# return current_sample.unsqueeze(1).cuda() + + if torch.cuda.is_available(): + return current_sample.reshape(b, T, -1).cuda() + else: + return current_sample.reshape(b, T, -1) + + + +class GANLoss(nn.Module): + def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, + tensor=torch.FloatTensor): + super(GANLoss, self).__init__() +
self.real_label = target_real_label + self.fake_label = target_fake_label + self.real_label_var = None + self.fake_label_var = None + self.Tensor = tensor + if use_lsgan: + self.loss = nn.MSELoss() + else: + self.loss = nn.BCELoss() + + def get_target_tensor(self, input, target_is_real): + target_tensor = None + gpu_id = input.get_device() + if target_is_real: + create_label = ((self.real_label_var is None) or + (self.real_label_var.numel() != input.numel())) + if create_label: + real_tensor = self.Tensor(input.size()).cuda(gpu_id).fill_(self.real_label) + self.real_label_var = Variable(real_tensor, requires_grad=False) + target_tensor = self.real_label_var + else: + create_label = ((self.fake_label_var is None) or + (self.fake_label_var.numel() != input.numel())) + if create_label: + fake_tensor = self.Tensor(input.size()).cuda(gpu_id).fill_(self.fake_label) + self.fake_label_var = Variable(fake_tensor, requires_grad=False) + target_tensor = self.fake_label_var + return target_tensor + + def __call__(self, input, target_is_real): + if isinstance(input[0], list): + loss = 0 + for input_i in input: + pred = input_i[-1] + target_tensor = self.get_target_tensor(pred, target_is_real) + loss += self.loss(pred, target_tensor) + return loss + else: + target_tensor = self.get_target_tensor(input[-1], target_is_real) + return self.loss(input[-1], target_tensor) + + + + +class VGGLoss(nn.Module): + def __init__(self, model=None): + super(VGGLoss, self).__init__() + if model is None: + self.vgg = Vgg19() + else: + self.vgg = model + + self.vgg.cuda() + # self.vgg.eval() + self.criterion = nn.L1Loss() + self.style_criterion = StyleLoss() + self.weights = [1.0, 1.0, 1.0, 1.0, 1.0] + self.style_weights = [1.0, 1.0, 1.0, 1.0, 1.0] + # self.weights = [5.0, 1.0, 0.5, 0.4, 0.8] + # self.style_weights = [10e4, 1000, 50, 15, 50] + + def forward(self, x, y, style=False): + x_vgg, y_vgg = self.vgg(x), self.vgg(y) + loss = 0 + if style: + # return both perceptual loss and style loss. 
+ style_loss = 0 + for i in range(len(x_vgg)): + this_loss = (self.weights[i] * + self.criterion(x_vgg[i], y_vgg[i].detach())) + this_style_loss = (self.style_weights[i] * + self.style_criterion(x_vgg[i], y_vgg[i].detach())) + loss += this_loss + style_loss += this_style_loss + return loss, style_loss + + for i in range(len(x_vgg)): + this_loss = (self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())) + loss += this_loss + return loss + + +def gram_matrix(input): + a, b, c, d = input.size() # a=batch size(=1) + # b=number of feature maps + # (c,d)=dimensions of a f. map (N=c*d) + features = input.view(a * b, c * d) # reshape F_XL into \hat F_XL + G = torch.mm(features, features.t()) # compute the gram product + # we 'normalize' the values of the gram matrix + # by dividing by the number of elements in each feature map. + return G.div(a * b * c * d) + + +class StyleLoss(nn.Module): + def __init__(self): + super(StyleLoss, self).__init__() + + def forward(self, x, y): + Gx = gram_matrix(x) + Gy = gram_matrix(y) + return F.mse_loss(Gx, Gy) * 30000000 + + + +class MaskedL1Loss(nn.Module): + def __init__(self): + super(MaskedL1Loss, self).__init__() + self.criterion = nn.L1Loss() + + def forward(self, input, target, mask): + mask = mask.expand(-1, input.size()[1], -1, -1) + loss = self.criterion(input * mask, target * mask) + return loss + + + +from torchvision import models +class Vgg19(nn.Module): + def __init__(self, requires_grad=False): + super(Vgg19, self).__init__() + vgg_pretrained_features = models.vgg19(pretrained=True).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + for x in range(2): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(2, 7): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(7, 12): + self.slice3.add_module(str(x),
vgg_pretrained_features[x]) + for x in range(12, 21): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(21, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h_relu1 = self.slice1(X) + h_relu2 = self.slice2(h_relu1) + h_relu3 = self.slice3(h_relu2) + h_relu4 = self.slice4(h_relu3) + h_relu5 = self.slice5(h_relu4) + out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] + return out + + + + + + + + + + + diff --git a/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/networks.py b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/networks.py new file mode 100644 index 00000000..2a684481 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/LiveSpeechPortraits/networks.py @@ -0,0 +1,873 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.optim import lr_scheduler +from torch.nn import init +import functools + + +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence + + + +############################################################################### +# The detailed network architecture implementation for each model +############################################################################### + +class APC_encoder(nn.Module): + def __init__(self, + mel_dim, + hidden_size, + num_layers, + residual): + super(APC_encoder, self).__init__() + + input_size = mel_dim + + in_sizes = ([input_size] + [hidden_size] * (num_layers - 1)) + out_sizes = [hidden_size] * num_layers + self.rnns = nn.ModuleList( + [nn.GRU(input_size=in_size, hidden_size=out_size, batch_first=True) for (in_size, out_size) in zip(in_sizes, out_sizes)]) + + self.rnn_residual = residual + + def forward(self, inputs, lengths): + ''' + input: + inputs: (batch_size, seq_len, 
mel_dim) + lengths: (batch_size,) + + return: + rnn_outputs: (batch_size, seq_len, rnn_hidden_size), the hidden + states of the last RNN layer + ''' + with torch.no_grad(): + seq_len = inputs.size(1) + packed_rnn_inputs = pack_padded_sequence(inputs, lengths, True) + + for i, layer in enumerate(self.rnns): + packed_rnn_outputs, _ = layer(packed_rnn_inputs) + + rnn_outputs, _ = pad_packed_sequence( + packed_rnn_outputs, True, total_length=seq_len) + # outputs: (batch_size, seq_len, rnn_hidden_size) + + if i + 1 < len(self.rnns): + rnn_inputs, _ = pad_packed_sequence( + packed_rnn_inputs, True, total_length=seq_len) + # rnn_inputs: (batch_size, seq_len, rnn_hidden_size) + if self.rnn_residual and rnn_inputs.size(-1) == rnn_outputs.size(-1): + # Residual connections + rnn_outputs = rnn_outputs + rnn_inputs + packed_rnn_inputs = pack_padded_sequence(rnn_outputs, lengths, True) + + + return rnn_outputs + + + + +class WaveNet(nn.Module): + ''' This is a complete implementation of the WaveNet architecture, mainly composed + of several residual blocks and some other operations. + Args: + batch_size: batch size + residual_layers: number of layers in each residual block + residual_blocks: number of residual blocks + dilation_channels: number of channels for the dilated convolution + residual_channels: number of channels for the residual connections + skip_channels: number of channels for the skip connections + end_channels: number of channels for the end convolution + classes: number of possible values each sample can have as output + kernel_size: size of the dilated convolution kernel + output_length(int): number of samples that are generated for each input + use_bias: whether bias is used in each layer. + cond(bool): whether conditioning information is applied. if cond == True: + cond_channels: number of channels of the conditioning information + `` loss(str): GMM loss is adopted.
`` + ''' + def __init__(self, + residual_layers = 10, + residual_blocks = 3, + dilation_channels = 32, + residual_channels = 32, + skip_channels = 256, + kernel_size = 2, + output_length = 16, + use_bias = False, + cond = True, + input_channels = 128, + ncenter = 1, + ndim = 73*2, + output_channels = 73*3, + cond_channels = 256, + activation = 'leakyrelu'): + super(WaveNet, self).__init__() + + self.layers = residual_layers + self.blocks = residual_blocks + self.dilation_channels = dilation_channels + self.residual_channels = residual_channels + self.skip_channels = skip_channels + self.input_channels = input_channels + self.ncenter = ncenter + self.ndim = ndim +# self.output_channels = (2 * self.ndim + 1) * self.ncenter + self.output_channels = output_channels + self.kernel_size = kernel_size + self.output_length = output_length + self.bias = use_bias + self.cond = cond + self.cond_channels = cond_channels + + # build modules + self.dilations = [] + self.dilation_queues = [] + residual_blocks = [] + self.receptive_field = 1 + + # 1x1 convolution to create channels + self.start_conv1 = nn.Conv1d(in_channels=self.input_channels, + out_channels=self.residual_channels, + kernel_size=1, + bias=True) + self.start_conv2 = nn.Conv1d(in_channels=self.residual_channels, + out_channels=self.residual_channels, + kernel_size=1, + bias=True) + if activation == 'relu': + self.activation = nn.ReLU(inplace = True) + elif activation == 'leakyrelu': + self.activation = nn.LeakyReLU(0.2) + self.drop_out2D = nn.Dropout2d(p=0.5) + + + # build residual blocks + for b in range(self.blocks): + new_dilation = 1 + additional_scope = kernel_size - 1 + for i in range(self.layers): + # create current residual block + residual_blocks.append(residual_block(dilation = new_dilation, + dilation_channels = self.dilation_channels, + residual_channels = self.residual_channels, + skip_channels = self.skip_channels, + kernel_size = self.kernel_size, + use_bias = self.bias, + cond = self.cond, + 
cond_channels = self.cond_channels)) + new_dilation *= 2 + + self.receptive_field += additional_scope + additional_scope *= 2 + + self.residual_blocks = nn.ModuleList(residual_blocks) + # end convolutions + + self.end_conv_1 = nn.Conv1d(in_channels = self.skip_channels, + out_channels = self.output_channels, + kernel_size = 1, + bias = True) + self.end_conv_2 = nn.Conv1d(in_channels = self.output_channels, + out_channels = self.output_channels, + kernel_size = 1, + bias = True) + + + def parameter_count(self): + par = list(self.parameters()) + s = sum([np.prod(list(d.size())) for d in par]) + return s + + def forward(self, input, cond=None): + ''' + Args: + input: [b, ndim, T] + cond: [b, nfeature, T] + Returns: + res: [b, T, ndim] + ''' + # dropout + x = self.drop_out2D(input) + + # preprocess + x = self.activation(self.start_conv1(x)) + x = self.activation(self.start_conv2(x)) + skip = 0 +# for i in range(self.blocks * self.layers): + for i, dilation_block in enumerate(self.residual_blocks): + x, current_skip = self.residual_blocks[i](x, cond) + skip += current_skip + + # postprocess + res = self.end_conv_1(self.activation(skip)) + res = self.end_conv_2(self.activation(res)) + + # cut the output size + res = res[:, :, -self.output_length:] # [b, ndim, T] + res = res.transpose(1, 2) # [b, T, ndim] + + return res + + + +class residual_block(nn.Module): + ''' + This is the implementation of a residual block in wavenet model. Every + residual block takes previous block's output as input. 
The forward pass of + each residual block can be illustrated as below: + + ######################### Current Residual Block ########################## + # |-----------------------*residual*--------------------| # + # | | # + # | |-- dilated conv -- tanh --| | # + # -> -|-- pad--| * ---- |-- 1x1 -- + --> *input* # + # |-- dilated conv -- sigm --| | # + # 1x1 # + # | # + # ---------------------------------------------> + -------------> *skip* # + ########################################################################### + As shown above, each residual block returns two values: 'input' and 'skip': + 'input' is this block's output and also serves as the next block's input. + 'skip' is the skip connection output, which is summed at the end to compute the prediction. + The arguments have the same meaning as in the WaveNet class. + + ''' + def __init__(self, + dilation, + dilation_channels = 32, + residual_channels = 32, + skip_channels = 256, + kernel_size = 2, + use_bias = False, + cond = True, + cond_channels = 128): + super(residual_block, self).__init__() + + self.dilation = dilation + self.dilation_channels = dilation_channels + self.residual_channels = residual_channels + self.skip_channels = skip_channels + self.kernel_size = kernel_size + self.bias = use_bias + self.cond = cond + self.cond_channels = cond_channels + # zero padding to the left of the sequence.
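+ # Example (illustrative, not part of the original code): with the default
+ # kernel_size=2 and dilation=4, padding=(4, 0) prepends 4 zeros, so the
+ # dilated conv sees only current and past samples and the output length
+ # equals the input length (a causal convolution).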
+ self.padding = (int((self.kernel_size - 1) * self.dilation), 0) + + # dilated convolutions + self.filter_conv = nn.Conv1d(in_channels = self.residual_channels, + out_channels = self.dilation_channels, + kernel_size = self.kernel_size, + dilation = self.dilation, + bias = self.bias) + + self.gate_conv = nn.Conv1d(in_channels = self.residual_channels, + out_channels = self.dilation_channels, + kernel_size = self.kernel_size, + dilation = self.dilation, + bias = self.bias) + + # 1x1 convolution for residual connections + self.residual_conv = nn.Conv1d(in_channels = self.dilation_channels, + out_channels = self.residual_channels, + kernel_size = 1, + bias = self.bias) + + # 1x1 convolution for skip connections + self.skip_conv = nn.Conv1d(in_channels = self.dilation_channels, + out_channels = self.skip_channels, + kernel_size = 1, + bias = self.bias) + + # condition conv, no dilation + if self.cond: + self.cond_filter_conv = nn.Conv1d(in_channels = self.cond_channels, + out_channels = self.dilation_channels, + kernel_size = 1, + bias = True) + self.cond_gate_conv = nn.Conv1d(in_channels = self.cond_channels, + out_channels = self.dilation_channels, + kernel_size = 1, + bias = True) + + + def forward(self, input, cond=None): + if self.cond and cond is None: + raise RuntimeError("cond is set to True, but no cond tensor was provided") + + x_pad = F.pad(input, self.padding) + # filter + filter = self.filter_conv(x_pad) + # gate + gate = self.gate_conv(x_pad) + + if self.cond and cond is not None: + filter_cond = self.cond_filter_conv(cond) + gate_cond = self.cond_gate_conv(cond) + # add cond results + filter = filter + filter_cond + gate = gate + gate_cond + + # element-wise multiplication + filter = torch.tanh(filter) + gate = torch.sigmoid(gate) + x = filter * gate + + # residual and skip + residual = self.residual_conv(x) + input + skip = self.skip_conv(x) + + + return residual, skip + + + + +## 2D convolution layers +def conv2d(batch_norm,
in_planes, out_planes, kernel_size=3, stride=1): + if batch_norm: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), + nn.BatchNorm2d(out_planes), + nn.LeakyReLU(0.2, inplace=True) + ) + else: + return nn.Sequential( + nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), + nn.LeakyReLU(0.2, inplace=True) + ) + + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 
+ init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[], useDDP=False): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. + """ + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + if useDDP: + net = net.to(gpu_ids[0]) + net = DDP(net, device_ids=gpu_ids) # DDP + print(f'use DDP to apply models on {gpu_ids}') + else: + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + init_weights(net, init_type, init_gain=init_gain) + return net + + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For 'linear', we keep the same learning rate for the first opt.n_epochs epochs + and linearly decay the rate to zero over the next opt.n_epochs_decay epochs. + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details.
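+ Example (illustrative values, not from the original options): with
+ opt.n_epochs = 100 and opt.n_epochs_decay = 100, lambda_rule returns a
+ multiplier of 1.0 up to epoch 100, ~0.505 at epoch 150, and ~0.01 at
+ epoch 200, so the LR decays linearly to (almost) zero.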
+ """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch - opt.n_epochs) / float(opt.n_epochs_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule, last_epoch=opt.epoch_count-2) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=opt.gamma, last_epoch=opt.epoch_count-2) + for _ in range(opt.epoch_count-2): + scheduler.step() + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1 and hasattr(m, 'weight'): + m.weight.data.normal_(0.0, 0.02) + elif classname.find('BatchNorm2d') != -1: + m.weight.data.normal_(1.0, 0.02) + m.bias.data.fill_(0) + + +def print_network(net): + if isinstance(net, list): + net = net[0] + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print('Total number of parameters: %d' % num_params) + + + + +class Feature2FaceGenerator_normal(nn.Module): + def __init__(self, input_nc=4, output_nc=3, num_downs=8, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(Feature2FaceGenerator_normal, self).__init__() + # construct unet structure + unet_block = ResUnetSkipConnectionBlock_small(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, + innermost=True) + + for i in range(num_downs - 5): + unet_block = ResUnetSkipConnectionBlock_small(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, + norm_layer=norm_layer, use_dropout=use_dropout) + unet_block = ResUnetSkipConnectionBlock_small(ngf * 4, ngf * 8, input_nc=None, 
submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock_small(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock_small(ngf, ngf * 2, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock_small(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, + norm_layer=norm_layer) + + self.model = unet_block + + def forward(self, input): + output = self.model(input) + output = torch.tanh(output) # scale to [-1, 1] + + return output + + +# Defines the submodule with skip connection. +# X -------------------identity---------------------- X +# |-- downsampling -- |submodule| -- upsampling --| +class ResUnetSkipConnectionBlock_small(nn.Module): + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(ResUnetSkipConnectionBlock_small, self).__init__() + self.outermost = outermost + use_bias = norm_layer == nn.InstanceNorm2d + + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3, + stride=2, padding=1, bias=use_bias) + # add two resblock + res_downconv = [ResidualBlock(inner_nc, norm_layer)] + res_upconv = [ResidualBlock(outer_nc, norm_layer)] + + # res_downconv = [ResidualBlock(inner_nc)] + # res_upconv = [ResidualBlock(outer_nc)] + + downrelu = nn.ReLU(True) + uprelu = nn.ReLU(True) + if norm_layer != None: + downnorm = norm_layer(inner_nc) + upnorm = norm_layer(outer_nc) + + if outermost: + upsample = nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + down = [downconv, downrelu] + res_downconv + # up = [uprelu, upsample, upconv, upnorm] + up = [upsample, upconv] + model = down + [submodule] + up + elif innermost: + upsample = nn.Upsample(scale_factor=2, 
mode='nearest') + upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + down = [downconv, downrelu] + res_downconv + if norm_layer == None: + up = [upsample, upconv, uprelu] + res_upconv + else: + up = [upsample, upconv, upnorm, uprelu] + res_upconv + model = down + up + else: + upsample = nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + if norm_layer == None: + down = [downconv, downrelu] + res_downconv + up = [upsample, upconv, uprelu] + res_upconv + else: + down = [downconv, downnorm, downrelu] + res_downconv + up = [upsample, upconv, upnorm, uprelu] + res_upconv + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: + return torch.cat([x, self.model(x)], 1) + + + +class Feature2FaceGenerator_large(nn.Module): + def __init__(self, input_nc=4, output_nc=3, num_downs=8, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(Feature2FaceGenerator_large, self).__init__() + # construct unet structure + unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, + innermost=True) + + for i in range(num_downs - 5): + unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, + norm_layer=norm_layer, use_dropout=use_dropout) + unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, 
submodule=unet_block, outermost=True, + norm_layer=norm_layer) + + self.model = unet_block + + def forward(self, input): + output = self.model(input) + output = torch.tanh(output) # scale to [-1, 1] + + return output + + +# Defines the submodule with skip connection. +# X -------------------identity---------------------- X +# |-- downsampling -- |submodule| -- upsampling --| +class ResUnetSkipConnectionBlock(nn.Module): + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(ResUnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + use_bias = norm_layer == nn.InstanceNorm2d + + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3, + stride=2, padding=1, bias=use_bias) + # add two resblock + res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)] + res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)] + + # res_downconv = [ResidualBlock(inner_nc)] + # res_upconv = [ResidualBlock(outer_nc)] + + downrelu = nn.ReLU(True) + uprelu = nn.ReLU(True) + if norm_layer != None: + downnorm = norm_layer(inner_nc) + upnorm = norm_layer(outer_nc) + + if outermost: + upsample = nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + down = [downconv, downrelu] + res_downconv + # up = [uprelu, upsample, upconv, upnorm] + up = [upsample, upconv] + model = down + [submodule] + up + elif innermost: + upsample = nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + down = [downconv, downrelu] + res_downconv + if norm_layer == None: + up = [upsample, upconv, uprelu] + res_upconv + else: + up = [upsample, upconv, upnorm, uprelu] + res_upconv + model = down + up + else: + upsample = 
nn.Upsample(scale_factor=2, mode='nearest') + upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) + if norm_layer == None: + down = [downconv, downrelu] + res_downconv + up = [upsample, upconv, uprelu] + res_upconv + else: + down = [downconv, downnorm, downrelu] + res_downconv + up = [upsample, upconv, upnorm, uprelu] + res_upconv + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: + return torch.cat([x, self.model(x)], 1) + + +# UNet with residual blocks +class ResidualBlock(nn.Module): + def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d): + super(ResidualBlock, self).__init__() + self.relu = nn.ReLU(True) + if norm_layer == None: + # hard to converge with out batch or instance norm + self.block = nn.Sequential( + nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), + nn.ReLU(inplace=True), + nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), + ) + else: + self.block = nn.Sequential( + nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), + norm_layer(in_features), + nn.ReLU(inplace=True), + nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), + norm_layer(in_features) + ) + + def forward(self, x): + residual = x + out = self.block(x) + out += residual + out = self.relu(out) + return out + # return self.relu(x + self.block(x)) + + + +class Feature2FaceGenerator_Unet(nn.Module): + def __init__(self, input_nc=4, output_nc=3, num_downs=8, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(Feature2FaceGenerator_Unet, self).__init__() + + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf 
* 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + + def forward(self, input): + output = self.model(input) + + return output + + + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. 
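The skip connections concatenate a block's input with its output along the channel axis (`torch.cat([x, self.model(x)], 1)`), which is why every non-innermost up-convolution takes `inner_nc * 2` input channels. A torch-free sketch of the resulting channel plan for the default `ngf=64`, `num_downs=8` construction above (`unet_channel_plan` is a hypothetical helper; the outermost block is omitted):

```python
def unet_channel_plan(ngf=64, num_downs=8):
    """Per-block (outer_nc, inner_nc, upconv_in), innermost first."""
    plan = [(ngf * 8, ngf * 8, ngf * 8)]          # innermost: no inner skip to concat
    for _ in range(num_downs - 5):                # intermediate ngf*8 blocks
        plan.append((ngf * 8, ngf * 8, ngf * 8 * 2))
    for outer, inner in [(ngf * 4, ngf * 8), (ngf * 2, ngf * 4), (ngf, ngf * 2)]:
        plan.append((outer, inner, inner * 2))    # skip doubles the upconv input
    return plan

print(unet_channel_plan()[0], unet_channel_plan()[-1])
```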
+ """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + + + +class MultiscaleDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, + num_D=3, getIntermFeat=False): + super(MultiscaleDiscriminator, self).__init__() + self.num_D = num_D + self.n_layers = n_layers + self.getIntermFeat = getIntermFeat + ndf_max = 64 + + for i in range(num_D): + netD = NLayerDiscriminator(input_nc, min(ndf_max, ndf*(2**(num_D-1-i))), n_layers, getIntermFeat) + if getIntermFeat: + for j in range(n_layers+2): + setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j))) + else: + setattr(self, 'layer'+str(i), 
netD.model) + + self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) + + def singleD_forward(self, model, input): + if self.getIntermFeat: + result = [input] + for i in range(len(model)): + result.append(model[i](result[-1])) + return result[1:] + else: + return [model(input)] + + def forward(self, input): + num_D = self.num_D + result = [] + input_downsampled = input + for i in range(num_D): + if self.getIntermFeat: + model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)] + else: + model = getattr(self, 'layer'+str(num_D-1-i)) + result.append(self.singleD_forward(model, input_downsampled)) + if i != (num_D-1): + input_downsampled = self.downsample(input_downsampled) + return result + + + +# Defines the PatchGAN discriminator with the specified arguments. +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, getIntermFeat=False): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), + nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + nn.BatchNorm2d(nf), + nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + nn.BatchNorm2d(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + + def forward(self, input): + if 
self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[1:] + else: + return self.model(input) + + + + + + diff --git a/talkingface/model/audio_driven_talkingface/__init__.py b/talkingface/model/audio_driven_talkingface/__init__.py index 04a35f33..1280d7ce 100644 --- a/talkingface/model/audio_driven_talkingface/__init__.py +++ b/talkingface/model/audio_driven_talkingface/__init__.py @@ -1 +1,2 @@ +from talkingface.model.audio_driven_talkingface.LiveSpeechPortraits import * from talkingface.model.audio_driven_talkingface.wav2lip import Wav2Lip, SyncNet_color \ No newline at end of file diff --git a/talkingface/model/audio_driven_talkingface/live_speech_portraits.py b/talkingface/model/audio_driven_talkingface/live_speech_portraits.py new file mode 100644 index 00000000..f8fa2fd2 --- /dev/null +++ b/talkingface/model/audio_driven_talkingface/live_speech_portraits.py @@ -0,0 +1,1121 @@ +import os +from os.path import join +import torch +import torch.nn as nn +import torch.nn.functional as F +import scipy.io as sio +from skimage.io import imread +from torch.optim import lr_scheduler +from torch.nn import init +import functools +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence +import numpy as np +import librosa +from tqdm import tqdm +from cog import BasePredictor, Input, Path +from talkingface.utils import set_color +from logging import getLogger +from collections import OrderedDict +from torch.cuda.amp import autocast as autocast +import albumentations +from abc import ABC, abstractmethod +from .LiveSpeechPortraits import networks +from talkingface.model.abstract_talkingface import AbstractTalkingFace +import talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.audio2feature as audio2feature +import talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.audio2headpose 
as audio2headpose +import talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.feature2face_G as feature2face_G +import talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.feature2face_D as feature2face_D +from talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.losses import GMMLogLoss, Sample_GMM, GANLoss, MaskedL1Loss, VGGLoss +from talkingface.model.audio_driven_talkingface.LiveSpeechPortraits import create_model +from talkingface.data.dataset.LiveSpeechPortraits import create_dataset +from talkingface.utils.live_speech_portraits.options.test_audio2feature_options import TestOptions as FeatureOptions +from talkingface.utils.live_speech_portraits.options.test_audio2headpose_options import TestOptions as HeadposeOptions +from talkingface.utils.live_speech_portraits.options.test_feature2face_options import TestOptions as RenderOptions +from talkingface.utils.live_speech_portraits import utils +import talkingface.utils.live_speech_portraits.util.util as util +from talkingface.utils.live_speech_portraits.util.visualizer import Visualizer +from talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.networks import APC_encoder + +class BaseModel(ABC): + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- <set_input>: unpack data from dataset and apply preprocessing. + -- <forward>: produce intermediate results. + -- <optimize_parameters>: calculate losses, gradients, and update network weights. + -- <modify_commandline_options>: (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization.
+ In this function, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + # get device name: CPU or GPU + # if self.gpu_ids == '-1': + # self.device = torch.device('cpu') + # self.gpu_ids = opt.gpu_ids == [] + # else: + # self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if len(self.gpu_ids) > 0 else torch.device('cpu') + + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + # torch speed up training + torch.backends.cudnn.enabled = True + torch.backends.cudnn.benchmark = True + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. 
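As a torch-free sketch of this subclass contract (stdlib `abc` only; `ToyBase` and `ToyModel` are hypothetical stand-ins for BaseModel and a concrete model):

```python
from abc import ABC, abstractmethod

class ToyBase(ABC):
    def __init__(self):
        self.loss_names, self.model_names = [], []

    @abstractmethod
    def set_input(self, data): ...

    @abstractmethod
    def forward(self): ...

    @abstractmethod
    def optimize_parameters(self): ...

class ToyModel(ToyBase):
    def __init__(self):
        super().__init__()
        self.loss_names = ['L1']       # losses that get logged
        self.model_names = ['ToyNet']  # networks that get saved/loaded

    def set_input(self, data):
        self.x = data['x']             # unpack and preprocess

    def forward(self):
        self.pred = self.x * 2         # produce intermediate results

    def optimize_parameters(self):
        self.forward()
        self.loss_L1 = abs(self.pred - 1.0)

m = ToyModel()
m.set_input({'x': 3.0})
m.optimize_parameters()
print(m.loss_L1)
```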
+ """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + if not self.isTrain or opt.continue_train: + self.load_networks(opt.load_epoch) + self.print_networks(opt.verbose) + + def train(self): + """Make models train mode during train time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.train(mode=True) + + def eval(self): + """Make models eval mode during test time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + net.eval() + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self): + """ Return image paths that are used to load current data""" + return self.image_paths + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate = %.7f' % lr) + + def get_current_visuals(self): + """Return visualization images. 
train.py will display these images with visdom, and save the images to an HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name) + return visual_ret + + def get_current_losses(self): + """Return training losses / errors. train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float( + getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch, train_info=None): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_%s.pkl' % (epoch, name) + """ + for name in self.model_names: + if isinstance(name, str): + save_filename = '%s_%s.pkl' % (epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, name) + torch.save(net.state_dict(), save_path) + if train_info is not None: + epoch, epoch_iter = train_info + iter_path = os.path.join(self.save_dir, 'iter.txt') + np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk.
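Checkpoints are looked up as `'%s_%s.pkl' % (epoch, name)`, and when loading on CPU the `'module.'` prefix that `DataParallel` adds to every key is sliced off (hence `key[7:]` in the code that follows). A stdlib sketch; note the original slices every key unconditionally, while the hypothetical helper here guards with `startswith` as a safety assumption:

```python
import os

def checkpoint_path(save_dir, epoch, name):
    # mirrors save_networks/load_networks: '%s_%s.pkl' % (epoch, name)
    return os.path.join(save_dir, '%s_%s.pkl' % (epoch, name))

def strip_module_prefix(state_dict):
    # 'module.' is 7 characters long, which is where key[7:] comes from
    return {(k[7:] if k.startswith('module.') else k): v
            for k, v in state_dict.items()}

print(checkpoint_path('checkpoints/run1', 500, 'Audio2Headpose'))
print(strip_module_prefix({'module.conv.weight': 0}))
```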
+ + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + + for name in self.model_names: + if isinstance(name, str): + if epoch[-3:] == 'pkl': + load_path = epoch + else: + load_filename = '%s_%s.pkl' % (epoch, name) + load_path = os.path.join(self.save_dir, load_filename) + net = getattr(self, name) + # if isinstance(net, torch.nn.DataParallel): + # net = net.module + if os.path.exists(load_path): + state_dict = torch.load(load_path, map_location=str(self.device)) + if self.device == torch.device('cpu'): + for key in list(state_dict.keys()): + state_dict[key[7:]] = state_dict.pop(key) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + print('loading the model from %s' % load_path) + net.load_state_dict(state_dict, strict=False) + else: + print('No model weight file:', load_path, 'initialize model without pre-trained weights.') + if self.isTrain == False: + raise ValueError( + 'We are now in inference process, no pre-trained model found! 
Check the model checkpoint!') + + # if isinstance(net, torch.nn.DataParallel): + # net = net.module + + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + + # state_dict = torch.load(load_path, map_location=str(self.device)) + # if hasattr(state_dict, '_metadata'): + # del state_dict._metadata + # + # # patch InstanceNorm checkpoints prior to 0.4 + # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) + # net.load_state_dict(state_dict) + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requires_grad=False for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad + + +class Audio2HeadposeModel(BaseModel): + def __init__(self, opt): + """Initialize the Audio2Headpose class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + + # specify the models you want to save to the disk.
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> + # define networks + self.model_names = ['Audio2Headpose'] + if opt.feature_decoder == 'WaveNet': + self.Audio2Headpose = networks.init_net(audio2headpose.Audio2Headpose(opt), init_type='normal', + init_gain=0.02, gpu_ids=opt.gpu_ids) + elif opt.feature_decoder == 'LSTM': + self.Audio2Headpose = networks.init_net(audio2headpose.Audio2Headpose_LSTM(opt), init_type='normal', + init_gain=0.02, gpu_ids=opt.gpu_ids) + + # define only during training time + if self.isTrain: + # losses + self.criterion_GMM = GMMLogLoss(opt.A2H_GMM_ncenter, opt.A2H_GMM_ndim, opt.A2H_GMM_sigma_min).to( + self.device) + self.criterion_L2 = nn.MSELoss().cuda() + # optimizer + self.optimizer = torch.optim.Adam([{'params': self.Audio2Headpose.parameters(), + 'initial_lr': opt.lr}], lr=opt.lr, betas=(0.9, 0.99)) + + self.optimizers.append(self.optimizer) + + if opt.continue_train: + self.resume_training() + + def resume_training(self): + opt = self.opt + ### if continue training, recover previous states + print('Resuming from epoch %s ' % (opt.load_epoch)) + # change epoch count & update schedule settings + opt.epoch_count = int(opt.load_epoch) + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + # print learning rate + lr = self.optimizers[0].param_groups[0]['lr'] + print('update learning rate: {} -> {}'.format(opt.lr, lr)) + + def set_input(self, data, data_info=None): + """Unpack input data from the dataloader and perform necessary pre-processing steps.
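When `audio_windows == 2`, the `forward()` below fuses adjacent feature windows via `reshape(bs, -1, ndim * 2)`: the time axis halves while the feature axis doubles. The shape arithmetic as a torch-free sketch (`fused_window_shape` is a hypothetical helper):

```python
def fused_window_shape(bs, item_len, ndim, audio_windows=2):
    # mirrors reshape(bs, -1, ndim * 2) in the audio_windows == 2 branch
    if audio_windows == 2:
        assert item_len % 2 == 0, 'item_len must be even to pair windows'
        return (bs, item_len // 2, ndim * 2)
    return (bs, item_len, ndim)

print(fused_window_shape(4, 60, 512))
```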
+ """ + if self.opt.feature_decoder == 'WaveNet': + self.headpose_audio_feats, self.history_headpose, self.target_headpose = data + self.headpose_audio_feats = self.headpose_audio_feats.to(self.device) + self.history_headpose = self.history_headpose.to(self.device) + self.target_headpose = self.target_headpose.to(self.device) + elif self.opt.feature_decoder == 'LSTM': + self.headpose_audio_feats, self.target_headpose = data + self.headpose_audio_feats = self.headpose_audio_feats.to(self.device) + self.target_headpose = self.target_headpose.to(self.device) + + def forward(self): + ''' + Args: + history_landmarks: [b, T, ndim] + audio_features: [b, 1, nfeas, nwins] + Returns: + preds: [b, T, output_channels] + ''' + + if self.opt.audio_windows == 2: + bs, item_len, ndim = self.headpose_audio_feats.shape + self.headpose_audio_feats = self.headpose_audio_feats.reshape(bs, -1, ndim * 2) + else: + bs, item_len, _, ndim = self.headpose_audio_feats.shape + if self.opt.feature_decoder == 'WaveNet': + self.preds_headpose = self.Audio2Headpose.forward(self.history_headpose, self.headpose_audio_feats) + elif self.opt.feature_decoder == 'LSTM': + self.preds_headpose = self.Audio2Headpose.forward(self.headpose_audio_feats) + + def calculate_loss(self): + """ calculate loss in detail, only forward pass included""" + if self.opt.loss == 'GMM': + self.loss_GMM = self.criterion_GMM(self.preds_headpose, self.target_headpose) + self.loss = self.loss_GMM + elif self.opt.loss == 'L2': + self.loss_L2 = self.criterion_L2(self.preds_headpose, self.target_headpose) + self.loss = self.loss_L2 + + if not self.opt.smooth_loss == 0: + mu_gen = Sample_GMM(self.preds_headpose, + self.Audio2Headpose.module.WaveNet.ncenter, + self.Audio2Headpose.module.WaveNet.ndim, + sigma_scale=0) + self.smooth_loss = (mu_gen[:, 2:] + self.target_headpose[:, :-2] - 2 * self.target_headpose[:, 1:-1]).mean( + dim=2).abs().mean() + self.loss += self.smooth_loss * self.opt.smooth_loss + + def backward(self): + 
"""Calculate losses, gradients, and update network weights; called in every training iteration""" + self.calculate_loss() + self.loss.backward() + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.optimizer.zero_grad() # clear optimizer parameters grad + self.forward() # forward pass + self.backward() # calculate loss and gradients + self.optimizer.step() # update gradients + + def validate(self): + """ validate process """ + with torch.no_grad(): + self.forward() + self.calculate_loss() + + def generate_sequences(self, audio_feats, pre_headpose, fill_zero=True, sigma_scale=0.0, opt=[]): + ''' generate landmark sequences given audio and a initialized landmark. + Note that the audio input should have the same sample rate as the training. + Args: + audio_sequences: [n,], in numpy + init_landmarks: [npts, 2], in numpy + sample_rate: audio sample rate, should be same as training process. + method(str): optional, how to generate the sequence, indeed it is the + loss function during training process. Options are 'L2' or 'GMM'. 
+ Returns: + landmark_sequences: [T, npts, 2] prediction landmark sequences + ''' + + frame_future = opt.frame_future + audio_feats = audio_feats.reshape(-1, 512 * 2) + nframe = audio_feats.shape[0] - frame_future + pred_headpose = np.zeros([nframe, opt.A2H_GMM_ndim]) + + if opt.feature_decoder == 'WaveNet': + # fill zero or not + if fill_zero == True: + # headpose + audio_feats_insert = np.repeat(audio_feats[0], opt.A2H_receptive_field - 1) + audio_feats_insert = audio_feats_insert.reshape(-1, opt.A2H_receptive_field - 1).T + audio_feats = np.concatenate([audio_feats_insert, audio_feats]) + # history headpose + history_headpose = np.repeat(pre_headpose, opt.A2H_receptive_field) + history_headpose = history_headpose.reshape(-1, opt.A2H_receptive_field).T + history_headpose = torch.from_numpy(history_headpose).unsqueeze(0).float().to(self.device) + infer_start = 0 + else: + return None + + # evaluate mode + self.Audio2Headpose.eval() + + with torch.no_grad(): + for i in tqdm(range(infer_start, nframe), desc='generating headpose'): + history_start = i - infer_start + input_audio_feats = audio_feats[ + history_start + frame_future: history_start + frame_future + opt.A2H_receptive_field] + input_audio_feats = torch.from_numpy(input_audio_feats).unsqueeze(0).float().to(self.device) + + if self.opt.feature_decoder == 'WaveNet': + preds = self.Audio2Headpose.forward(history_headpose, input_audio_feats) + elif self.opt.feature_decoder == 'LSTM': + preds = self.Audio2Headpose.forward(input_audio_feats) + + if opt.loss == 'GMM': + pred_data = Sample_GMM(preds, opt.A2H_GMM_ncenter, opt.A2H_GMM_ndim, sigma_scale=sigma_scale) + elif opt.loss == 'L2': + pred_data = preds + + # get predictions + pred_headpose[i] = pred_data[0, 0].cpu().detach().numpy() + history_headpose = torch.cat((history_headpose[:, 1:, :], pred_data.to(self.device)), + dim=1) # add in time-axis + + return pred_headpose + + elif opt.feature_decoder == 'LSTM': + self.Audio2Headpose.eval() + with
torch.no_grad(): + input = torch.from_numpy(audio_feats).unsqueeze(0).float().to(self.device) + preds = self.Audio2Headpose.forward(input) + if opt.loss == 'GMM': + pred_data = Sample_GMM(preds, opt.A2H_GMM_ncenter, opt.A2H_GMM_ndim, sigma_scale=sigma_scale) + elif opt.loss == 'L2': + pred_data = preds + # get predictions + pred_headpose = pred_data[0].cpu().detach().numpy() + + return pred_headpose + + +class Feature2FaceModel(BaseModel): + def __init__(self, opt): + """Initialize the Feature2Face class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + self.Tensor = torch.cuda.FloatTensor + # specify the models you want to save to the disk. The training/test scripts will call and + # define networks + self.model_names = ['Feature2Face_G'] + self.Feature2Face_G = networks.init_net(feature2face_G.Feature2Face_G(opt), init_type='normal', init_gain=0.02, + gpu_ids=opt.gpu_ids) + if self.isTrain: + if not opt.no_discriminator: + self.model_names += ['Feature2Face_D'] + from . 
import feature2face_D + self.Feature2Face_D = networks.init_net(feature2face_D.Feature2Face_D(opt), init_type='normal', + init_gain=0.02, gpu_ids=opt.gpu_ids) + + # define only during training time + if self.isTrain: + # define losses names + self.loss_names_G = ['L1', 'VGG', 'Style', 'loss_G_GAN', 'loss_G_FM'] + # criterion + self.criterionMaskL1 = MaskedL1Loss().cuda() + self.criterionL1 = nn.L1Loss().cuda() + self.criterionVGG = VGGLoss().cuda() # instantiate the loss module before moving it to GPU + self.criterionFlow = nn.L1Loss().cuda() + + # initialize optimizer G + if opt.TTUR: + beta1, beta2 = 0, 0.9 + lr = opt.lr / 2 + else: + beta1, beta2 = opt.beta1, 0.999 + lr = opt.lr + self.optimizer_G = torch.optim.Adam([{'params': self.Feature2Face_G.module.parameters(), + 'initial_lr': lr}], + lr=lr, + betas=(beta1, beta2)) + self.optimizers.append(self.optimizer_G) + + # fp16 training + if opt.fp16: + self.scaler = torch.cuda.amp.GradScaler() + + # discriminator setting + if not opt.no_discriminator: + self.criterionGAN = GANLoss(opt.gan_mode, tensor=self.Tensor) + self.loss_names_D = ['D_real', 'D_fake'] + # initialize optimizer D + if opt.TTUR: + beta1, beta2 = 0, 0.9 + lr = opt.lr * 2 + else: + beta1, beta2 = opt.beta1, 0.999 + lr = opt.lr + self.optimizer_D = torch.optim.Adam([{'params': self.Feature2Face_D.module.netD.parameters(), + 'initial_lr': lr}], + lr=lr, + betas=(beta1, beta2)) + self.optimizers.append(self.optimizer_D) + + def init_paras(self, dataset): + opt = self.opt + iter_path = os.path.join(self.save_dir, 'iter.txt') + start_epoch, epoch_iter = 1, 0 + ### if continue training, recover previous states + if opt.continue_train: + if os.path.exists(iter_path): + start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int) + print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) + # change epoch count & update schedule settings + opt.epoch_count = start_epoch + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + # print learning
rate + lr = self.optimizers[0].param_groups[0]['lr'] + print('update learning rate: {} -> {}'.format(opt.lr, lr)) + else: + print('not found training log, hence training from epoch 1') + # change training sequence length + # if start_epoch > opt.nepochs_step: + # dataset.dataset.update_training_batch((start_epoch-1)//opt.nepochs_step) + + total_steps = (start_epoch - 1) * len(dataset) + epoch_iter + total_steps = total_steps // opt.print_freq * opt.print_freq + + return start_epoch, opt.print_freq, total_steps, epoch_iter + + def set_input(self, data, data_info=None): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + """ + self.feature_map, self.cand_image, self.tgt_image, self.facial_mask = \ + data['feature_map'], data['cand_image'], data['tgt_image'], data['weight_mask'] + self.feature_map = self.feature_map.to(self.device) + self.cand_image = self.cand_image.to(self.device) + self.tgt_image = self.tgt_image.to(self.device) + + # self.facial_mask = self.facial_mask.to(self.device) + + def forward(self): + ''' forward pass for feature2Face + ''' + self.input_feature_maps = torch.cat([self.feature_map, self.cand_image], dim=1) + self.fake_pred = self.Feature2Face_G(self.input_feature_maps) + + def backward_G(self): + """Calculate GAN and other loss for the generator""" + # GAN loss + real_AB = torch.cat((self.input_feature_maps, self.tgt_image), dim=1) + fake_AB = torch.cat((self.input_feature_maps, self.fake_pred), dim=1) + pred_real = self.Feature2Face_D(real_AB) + pred_fake = self.Feature2Face_D(fake_AB) + loss_G_GAN = self.criterionGAN(pred_fake, True) + # L1, vgg, style loss + loss_l1 = self.criterionL1(self.fake_pred, self.tgt_image) * self.opt.lambda_L1 + # loss_maskL1 = self.criterionMaskL1(self.fake_pred, self.tgt_image, self.facial_mask * self.opt.lambda_mask) + loss_vgg, loss_style = self.criterionVGG(self.fake_pred, self.tgt_image, style=True) + loss_vgg = torch.mean(loss_vgg) * self.opt.lambda_feat + loss_style = 
torch.mean(loss_style) * self.opt.lambda_feat + # feature matching loss + loss_FM = self.compute_FeatureMatching_loss(pred_fake, pred_real) + + # combine loss and calculate gradients + + if not self.opt.fp16: + self.loss_G = loss_G_GAN + loss_l1 + loss_vgg + loss_style + loss_FM # + loss_maskL1 + self.loss_G.backward() + else: + with autocast(): + self.loss_G = loss_G_GAN + loss_l1 + loss_vgg + loss_style + loss_FM # + loss_maskL1 + self.scaler.scale(self.loss_G).backward() + + self.loss_dict = {**self.loss_dict, + **dict(zip(self.loss_names_G, [loss_l1, loss_vgg, loss_style, loss_G_GAN, loss_FM]))} + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # GAN loss + real_AB = torch.cat((self.input_feature_maps, self.tgt_image), dim=1) + fake_AB = torch.cat((self.input_feature_maps, self.fake_pred), dim=1) + pred_real = self.Feature2Face_D(real_AB) + pred_fake = self.Feature2Face_D(fake_AB.detach()) + with autocast(): + loss_D_real = self.criterionGAN(pred_real, True) * 2 + loss_D_fake = self.criterionGAN(pred_fake, False) + + self.loss_D = (loss_D_fake + loss_D_real) * 0.5 + + self.loss_dict = dict(zip(self.loss_names_D, [loss_D_real, loss_D_fake])) + + if not self.opt.fp16: + self.loss_D.backward() + else: + self.scaler.scale(self.loss_D).backward() + + def compute_FeatureMatching_loss(self, pred_fake, pred_real): + # GAN feature matching loss + loss_FM = torch.zeros(1).cuda() + feat_weights = 4.0 / (self.opt.n_layers_D + 1) + D_weights = 1.0 / self.opt.num_D + for i in range(min(len(pred_fake), self.opt.num_D)): + for j in range(len(pred_fake[i])): + loss_FM += D_weights * feat_weights * \ + self.criterionL1(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat + + return loss_FM + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + # only train single image generation + ## forward + self.forward() + # update D + self.set_requires_grad(self.Feature2Face_D, True) # 
enable backprop for D + self.optimizer_D.zero_grad() # set D's gradients to zero + if not self.opt.fp16: + self.backward_D() # calculate gradients for D + self.optimizer_D.step() # update D's weights + else: + with autocast(): + self.backward_D() + self.scaler.step(self.optimizer_D) + + # update G + self.set_requires_grad(self.Feature2Face_D, False) # D requires no gradients when optimizing G + self.optimizer_G.zero_grad() # set G's gradients to zero + if not self.opt.fp16: + self.backward_G() # calculate gradients for G + self.optimizer_G.step() # update G's weights + else: + with autocast(): + self.backward_G() + self.scaler.step(self.optimizer_G) + self.scaler.update() + + def inference(self, feature_map, cand_image): + """ inference process """ + with torch.no_grad(): + if cand_image is None: + input_feature_maps = feature_map + else: + input_feature_maps = torch.cat([feature_map, cand_image], dim=1) + if not self.opt.fp16: + fake_pred = self.Feature2Face_G(input_feature_maps) + else: + with autocast(): + fake_pred = self.Feature2Face_G(input_feature_maps) + return fake_pred + +class Audio2FeatureModel(BaseModel): + def __init__(self, opt): + """Initialize the Audio2Feature class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + + # specify the models you want to save to the disk.
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>. + # define networks + self.model_names = ['Audio2Feature'] + self.Audio2Feature = networks.init_net(audio2feature.Audio2Feature(opt), init_type='normal', init_gain=0.02, + gpu_ids=opt.gpu_ids) + + # define only during training time + if self.isTrain: + # losses + self.featureL2loss = torch.nn.MSELoss().to(self.device) + # optimizer + self.optimizer = torch.optim.Adam([{'params': self.Audio2Feature.parameters(), + 'initial_lr': opt.lr}], lr=opt.lr, betas=(0.9, 0.99)) + + self.optimizers.append(self.optimizer) + + if opt.continue_train: + self.resume_training() + + def resume_training(self): + opt = self.opt + ### if continue training, recover previous states + print('Resuming from epoch %s ' % (opt.load_epoch)) + # change epoch count & update schedule settings + opt.epoch_count = int(opt.load_epoch) + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + # print learning rate + lr = self.optimizers[0].param_groups[0]['lr'] + print('update learning rate: {} -> {}'.format(opt.lr, lr)) + + def set_input(self, data, data_info=None): + """Unpack input data from the dataloader and perform necessary pre-processing steps. 
+ """ + + self.audio_feats, self.target_info = data + # b, item_length, mel_channels, width = self.audio_feats.shape + self.audio_feats = self.audio_feats.to(self.device) + self.target_info = self.target_info.to(self.device) + + # gaussian noise + + # if self.opt.gaussian_noise: + # self.audio_feats = self.opt.gaussian_noise_scale * torch.randn(self.audio_feats.shape).cuda() + # self.target_info += self.opt.gaussian_noise_scale * torch.randn(self.target_info.shape).cuda() + + def forward(self): + ''' + Args: + history_landmarks: [b, T, ndim] + audio_features: [b, 1, nfeas, nwins] + Returns: + preds: [b, T, output_channels] + ''' + self.preds = self.Audio2Feature.forward(self.audio_feats) + + def calculate_loss(self): + """ calculate loss in detail, only forward pass included""" + if self.opt.loss == 'GMM': + b, T, _ = self.target_info.shape + self.loss_GMM = self.criterion_GMM(self.preds, self.target_info) + self.loss = self.loss_GMM + + elif self.opt.loss == 'L2': + frame_future = self.opt.frame_future + if not frame_future == 0: + self.loss = self.featureL2loss(self.preds[:, frame_future:], self.target_info[:, :-frame_future]) * 1000 + else: + self.loss = self.featureL2loss(self.preds, self.target_info) * 1000 + + def backward(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + self.calculate_loss() + self.loss.backward() + + def optimize_parameters(self): + """Update network weights; it will be called in every training iteration.""" + self.optimizer.zero_grad() # clear optimizer parameters grad + self.forward() # forward pass + self.backward() # calculate loss and gradients + self.optimizer.step() # update gradients + + def validate(self): + """ validate process """ + with torch.no_grad(): + self.forward() + self.calculate_loss() + + def generate_sequences(self, audio_feats, sample_rate=16000, fps=60, fill_zero=True, opt=[]): + ''' generate landmark sequences given audio and a initialized landmark. 
+ Note that the audio input should have the same sample rate as the training. + Args: + audio_sequences: [n,], in numpy + init_landmarks: [npts, 2], in numpy + sample_rate: audio sample rate, should be same as training process. + method(str): optional, how to generate the sequence, indeed it is the + loss function during training process. Options are 'L2' or 'GMM'. + Reutrns: + landmark_sequences: [T, npts, 2] predition landmark sequences + ''' + + frame_future = opt.frame_future + nframe = int(audio_feats.shape[0] / 2) + + if not frame_future == 0: + audio_feats_insert = np.repeat(audio_feats[-1], 2 * (frame_future)).reshape(-1, 2 * (frame_future)).T + audio_feats = np.concatenate([audio_feats, audio_feats_insert]) + + # evaluate mode + self.Audio2Feature.eval() + + with torch.no_grad(): + input = torch.from_numpy(audio_feats).unsqueeze(0).float().to(self.device) + preds = self.Audio2Feature.forward(input) + + # drop first frame future results + if not frame_future == 0: + preds = preds[0, frame_future:].cpu().detach().numpy() + else: + preds = preds[0, :].cpu().detach().numpy() + + assert preds.shape[0] == nframe + + return preds + + +class live_speech_portraits(AbstractTalkingFace): + + def __init__(self, opt): + self.logger = getLogger() + super(live_speech_portraits, self).__init__() + self.opt = opt + return + + def calculate_loss(self, interaction): + """ calculate loss in detail, only forward pass included""" + if self.opt.loss == 'GMM': + self.loss_GMM = self.criterion_GMM(self.preds_headpose, self.target_headpose) + self.loss = self.loss_GMM + elif self.opt.loss == 'L2': + self.loss_L2 = self.criterion_L2(self.preds_headpose, self.target_headpose) + self.loss = self.loss_L2 + + if not self.opt.smooth_loss == 0: + mu_gen = Sample_GMM(self.preds_headpose, + self.Audio2Headpose.module.WaveNet.ncenter, + self.Audio2Headpose.module.WaveNet.ndim, + sigma_scale=0) + self.smooth_loss = (mu_gen[:, 2:] + self.target_headpose[:, :-2] - 2 * self.target_headpose[:, 
1:-1]).mean( + dim=2).abs().mean() + self.loss += self.smooth_loss * self.opt.smooth_loss + + def predict(self, driving_audio: Path = Input(description='driving audio, if the file is more than 20 seconds, only the first 20 seconds will be processed for video generation'), + talking_head: str = Input(description="choose a talking head", choices=['May', 'Obama1', 'Obama2', 'Nadella', 'McStay'], default='May') + ) -> Path: + ############################### I/O Settings ############################## + device = self.config['device'] + optID = self.config['dataset_params']['root'].split('/')[-1] + driving_audio = self.config['driving_audio_path'] + data_root = self.config['dataset_params']['root'] + + # create the results folder + audio_name = driving_audio.split('/')[-1].split('.')[-2] + save_root = join('./results/', optID, audio_name) + if not os.path.exists(save_root): + os.makedirs(save_root) + + + ############################ Hyper Parameters ############################# + h, w, sr, FPS = 512, 512, 16000, 60 + mouth_indices = np.concatenate([np.arange(4, 11), np.arange(46, 64)]) + eye_brow_indices = [27, 65, 28, 68, 29, 67, 30, 66, 31, 72, 32, 69, 33, 70, 34, 71] + eye_brow_indices = np.array(eye_brow_indices, np.int32) + + + ############################ Pre-defined Data ############################# + mean_pts3d = np.load(join(data_root, 'mean_pts3d.npy')) + fit_data = np.load(self.config['dataset_params']['fit_data_path']) + pts3d = np.load(self.config['dataset_params']['pts3d_path']) - mean_pts3d + trans = fit_data['trans'][:, :, 0].astype(np.float32) + mean_translation = trans.mean(axis=0) + candidate_eye_brow = pts3d[10:, eye_brow_indices] + std_mean_pts3d = np.load(self.config['dataset_params']['pts3d_path']).mean(axis=0) + + # candidates images + img_candidates = [] + for j in range(4): + output = imread(join(data_root, 'candidates', f'normalized_full_{j}.jpg')) + output = albumentations.pytorch.transforms.ToTensor(normalize={'mean': (0.5, 0.5, 0.5), + 
'std': (0.5, 0.5, 0.5)})(image=output)['image'] + img_candidates.append(output) + img_candidates = torch.cat(img_candidates).unsqueeze(0).to(device) + + # shoulders + shoulders = np.load(join(data_root, 'normalized_shoulder_points.npy')) + shoulder3D = np.load(join(data_root, 'shoulder_points3D.npy'))[1] + ref_trans = trans[1] + + # camera matrix, we always use training set intrinsic parameters. + camera = utils.camera() + camera_intrinsic = np.load(join(data_root, 'camera_intrinsic.npy')).astype(np.float32) + APC_feat_database = np.load(join(data_root, 'APC_feature_base.npy')) + + # load reconstruction data + scale = sio.loadmat(join(data_root, 'id_scale.mat'))['scale'][0, 0] + # Audio2Mel_torch = audio_funcs.Audio2Mel(n_fft=512, hop_length=int(16000/120), win_length=int(16000/60), sampling_rate=16000, + # n_mel_channels=80, mel_fmin=90, mel_fmax=7600.0).to(device) + + + + ########################### Experiment Settings ########################### + #### user config + use_LLE = self.config['model_params']['APC']['use_LLE'] + Knear = self.config['model_params']['APC']['Knear'] + LLE_percent = self.config['model_params']['APC']['LLE_percent'] + headpose_sigma = self.config['model_params']['Headpose']['sigma'] + Feat_smooth_sigma = self.config['model_params']['Audio2Mouth']['smooth'] + Head_smooth_sigma = self.config['model_params']['Headpose']['smooth'] + Feat_center_smooth_sigma, Head_center_smooth_sigma = 0, 0 + AMP_method = self.config['model_params']['Audio2Mouth']['AMP'][0] + Feat_AMPs = self.config['model_params']['Audio2Mouth']['AMP'][1:] + rot_AMP, trans_AMP = self.config['model_params']['Headpose']['AMP'] + shoulder_AMP = self.config['model_params']['Headpose']['shoulder_AMP'] + save_feature_maps = self.config['model_params']['Image2Image']['save_input'] + + #### common settings + Featopt = FeatureOptions().parse() + Headopt = HeadposeOptions().parse() + Renderopt = RenderOptions().parse() + Featopt.load_epoch = 
self.config['model_params']['Audio2Mouth']['ckp_path'] + Headopt.load_epoch = self.config['model_params']['Headpose']['ckp_path'] + Renderopt.dataroot = self.config['dataset_params']['root'] + Renderopt.load_epoch = self.config['model_params']['Image2Image']['ckp_path'] + Renderopt.size = self.config['model_params']['Image2Image']['size'] + ## GPU or CPU + if device == 'cpu': + Featopt.gpu_ids = Headopt.gpu_ids = Renderopt.gpu_ids = [] + + + + ############################# Load Models ################################# + print('---------- Loading Model: APC-------------') + APC_model = APC_encoder(self.config['model_params']['APC']['mel_dim'], + self.config['model_params']['APC']['hidden_size'], + self.config['model_params']['APC']['num_layers'], + self.config['model_params']['APC']['residual']) + APC_model.load_state_dict(torch.load(self.config['model_params']['APC']['ckp_path']), strict=False) + if device == 'cuda': + APC_model.cuda() + APC_model.eval() + print('---------- Loading Model: {} -------------'.format(Featopt.task)) + Audio2Feature = create_model(Featopt) + Audio2Feature.setup(Featopt) + Audio2Feature.eval() + print('---------- Loading Model: {} -------------'.format(Headopt.task)) + Audio2Headpose = create_model(Headopt) + Audio2Headpose.setup(Headopt) + Audio2Headpose.eval() + if Headopt.feature_decoder == 'WaveNet': + if device == 'cuda': + Headopt.A2H_receptive_field = Audio2Headpose.Audio2Headpose.module.WaveNet.receptive_field + else: + Headopt.A2H_receptive_field = Audio2Headpose.Audio2Headpose.WaveNet.receptive_field + print('---------- Loading Model: {} -------------'.format(Renderopt.task)) + facedataset = create_dataset(Renderopt) + Feature2Face = create_model(Renderopt) + Feature2Face.setup(Renderopt) + Feature2Face.eval() + visualizer = Visualizer(Renderopt) + + + ############################## Inference ################################## + print('Processing audio: {} ...'.format(audio_name)) + # read audio + audio, _ = 
librosa.load(driving_audio, sr=sr) + total_frames = np.int32(audio.shape[0] / sr * FPS) + + + #### 1. compute APC features + print('1. Computing APC features...') + mel80 = utils.compute_mel_one_sequence(audio, device=device) + mel_nframe = mel80.shape[0] + with torch.no_grad(): + length = torch.Tensor([mel_nframe]) + mel80_torch = torch.from_numpy(mel80.astype(np.float32)).to(device).unsqueeze(0) + hidden_reps = APC_model.forward(mel80_torch, length)[0] # [mel_nframe, 512] + hidden_reps = hidden_reps.cpu().numpy() + audio_feats = hidden_reps + + + #### 2. manifold projection + if use_LLE: + print('2. Manifold projection...') + ind = utils.KNN_with_torch(audio_feats, APC_feat_database, K=Knear) + weights, feat_fuse = utils.compute_LLE_projection_all_frame(audio_feats, APC_feat_database, ind, + audio_feats.shape[0]) + audio_feats = audio_feats * (1 - LLE_percent) + feat_fuse * LLE_percent + + + + #### 3. Audio2Mouth + print('3. Audio2Mouth inference...') + pred_Feat = Audio2Feature.generate_sequences(audio_feats, sr, FPS, fill_zero=True, opt=Featopt) + + + + + #### 4. Audio2Headpose + print('4. Headpose inference...') + # set history headposes as zero + pre_headpose = np.zeros(Headopt.A2H_wavenet_input_channels, np.float32) + pred_Head = Audio2Headpose.generate_sequences(audio_feats, pre_headpose, fill_zero=True, sigma_scale=0.3, + opt=Headopt) + + + + #### 5. Post-Processing + print('5. 
Post-processing...') + nframe = min(pred_Feat.shape[0], pred_Head.shape[0]) + pred_pts3d = np.zeros([nframe, 73, 3]) + pred_pts3d[:, mouth_indices] = pred_Feat.reshape(-1, 25, 3)[:nframe] + + + + ## mouth + pred_pts3d = utils.landmark_smooth_3d(pred_pts3d, Feat_smooth_sigma, area='only_mouth') + pred_pts3d = utils.mouth_pts_AMP(pred_pts3d, True, AMP_method, Feat_AMPs) + pred_pts3d = pred_pts3d + mean_pts3d + pred_pts3d = utils.solve_intersect_mouth(pred_pts3d) # fix intersecting lips if they exist + + + + ## headpose + pred_Head[:, 0:3] *= rot_AMP + pred_Head[:, 3:6] *= trans_AMP + pred_headpose = utils.headpose_smooth(pred_Head[:, :6], Head_smooth_sigma).astype(np.float32) + pred_headpose[:, 3:] += mean_translation + pred_headpose[:, 0] += 180 + + + + ## compute projected landmarks + pred_landmarks = np.zeros([nframe, 73, 2], dtype=np.float32) + final_pts3d = np.zeros([nframe, 73, 3], dtype=np.float32) + final_pts3d[:] = std_mean_pts3d.copy() + final_pts3d[:, 46:64] = pred_pts3d[:nframe, 46:64] + for k in tqdm(range(nframe)): + ind = k % candidate_eye_brow.shape[0] + final_pts3d[k, eye_brow_indices] = candidate_eye_brow[ind] + mean_pts3d[eye_brow_indices] + pred_landmarks[k], _, _ = utils.project_landmarks(camera_intrinsic, camera.relative_rotation, + camera.relative_translation, scale, + pred_headpose[k], final_pts3d[k]) + + + ## Upper Body Motion + pred_shoulders = np.zeros([nframe, 18, 2], dtype=np.float32) + pred_shoulders3D = np.zeros([nframe, 18, 3], dtype=np.float32) + for k in range(nframe): + diff_trans = pred_headpose[k][3:] - ref_trans + pred_shoulders3D[k] = shoulder3D + diff_trans * shoulder_AMP + # project + project = camera_intrinsic.dot(pred_shoulders3D[k].T) + project[:2, :] /= project[2, :] # divide z + pred_shoulders[k] = project[:2, :].T + + + + #### 6. Image2Image translation & Save results + print('6. 
Image2Image translation & Saving results...') + for ind in tqdm(range(0, nframe), desc='Image2Image translation inference'): + # feature_map: [input_nc, h, w] + current_pred_feature_map = facedataset.dataset.get_data_test_mode(pred_landmarks[ind], + pred_shoulders[ind], + facedataset.dataset.image_pad) + input_feature_maps = current_pred_feature_map.unsqueeze(0).to(device) + pred_fake = Feature2Face.inference(input_feature_maps, img_candidates) + # save results + visual_list = [('pred', util.tensor2im(pred_fake[0]))] + if save_feature_maps: + visual_list += [('input', np.uint8(current_pred_feature_map[0].cpu().numpy() * 255))] + visuals = OrderedDict(visual_list) + visualizer.save_images(save_root, visuals, str(ind + 1)) + + ## make videos + # generate corresponding audio, reused for all results + tmp_audio_path = join(save_root, 'tmp.wav') + tmp_audio_clip = audio[: np.int32(nframe * sr / FPS)] + librosa.output.write_wav(tmp_audio_path, tmp_audio_clip, sr) + + def generate_batch(self): + return + + def other_parameter(self): + if hasattr(self, "other_parameter_name"): + return {key: getattr(self, key) for key in self.other_parameter_name} + return dict() + + def load_other_parameter(self, para): + if para is None: + return + for key, value in para.items(): + setattr(self, key, value) + + def __str__(self): + """ + Model prints with number of trainable parameters + """ + model_parameters = filter(lambda p: p.requires_grad, self.parameters()) + params = sum([np.prod(p.size()) for p in model_parameters]) + return ( + super().__str__() + + set_color("\nTrainable parameters", "blue") + + f": {params}" + ) diff --git a/talkingface/model/text_to_speech/__init__.py b/talkingface/model/text_to_speech/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/talkingface/model/voice_conversion/__init__.py b/talkingface/model/voice_conversion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/talkingface/properties/model/live_speech_portraits.yaml b/talkingface/properties/model/live_speech_portraits.yaml new file mode 100644 index 00000000..ef3b2b78 --- /dev/null +++ b/talkingface/properties/model/live_speech_portraits.yaml @@ -0,0 +1,37 @@ +model_params: + APC: + ckp_path: './checkpoints/live_speech_portraits/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './checkpoints/live_speech_portraits/McStay/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './checkpoints/live_speech_portraits/McStay/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './checkpoints/live_speech_portraits/McStay/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + +# added by me +checkpoint_sub_dir: "/live_speech_portraits" # appended to checkpoint_dir in overall.yaml to form the final directory +temp_sub_dir: "/live_speech_portraits" # appended to temp_dir in overall.yaml to form the final directory +driving_audio_path: './checkpoints/live_speech_portraits/Input/00083.wav' # path to the driving audio +save_intermediates: 0 # whether to save intermediate files + +dataset_params: + root: './checkpoints/live_speech_portraits/McStay' + fit_data_path: './checkpoints/live_speech_portraits/McStay/3d_fit_data.npz' + pts3d_path: './checkpoints/live_speech_portraits/McStay/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface/properties/model/live_speech_portraits/May.yaml b/talkingface/properties/model/live_speech_portraits/May.yaml new file mode 100644 index 00000000..d28c4e76 --- /dev/null +++ b/talkingface/properties/model/live_speech_portraits/May.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: 
'./data/May/checkpoints/Audio2Feature.pkl' + smooth: 1.5 + AMP: ['XYZ', 2, 2, 2] # method, x, y, z + Headpose: + ckp_path: './data/May/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [1, 0.5] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/May/checkpoints/Feature2Face.pkl' + size: 'large' + save_input: 1 + + +dataset_params: + root: './data/May/' + fit_data_path: './data/May/3d_fit_data.npz' + pts3d_path: './data/May/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface/properties/model/live_speech_portraits/McStay.yaml b/talkingface/properties/model/live_speech_portraits/McStay.yaml new file mode 100644 index 00000000..25d9db17 --- /dev/null +++ b/talkingface/properties/model/live_speech_portraits/McStay.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/McStay/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/McStay/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/McStay/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/McStay/' + fit_data_path: './data/McStay/3d_fit_data.npz' + pts3d_path: './data/McStay/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface/properties/model/live_speech_portraits/Nadella.yaml b/talkingface/properties/model/live_speech_portraits/Nadella.yaml new file mode 100644 index 00000000..66f33573 --- /dev/null +++ b/talkingface/properties/model/live_speech_portraits/Nadella.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + 
LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/Nadella/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/Nadella/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [5, 10] # rot, trans + AMP: [0.5, 0.5] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/Nadella/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/Nadella/' + fit_data_path: './data/Nadella/3d_fit_data.npz' + pts3d_path: './data/Nadella/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface/properties/model/live_speech_portraits/Obama1.yaml b/talkingface/properties/model/live_speech_portraits/Obama1.yaml new file mode 100644 index 00000000..ce414876 --- /dev/null +++ b/talkingface/properties/model/live_speech_portraits/Obama1.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/Obama1/checkpoints/Audio2Feature.pkl' + smooth: 1 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/Obama1/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [2, 8] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/Obama1/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/Obama1/' + fit_data_path: './data/Obama1/3d_fit_data.npz' + pts3d_path: './data/Obama1/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface/properties/model/live_speech_portraits/Obama2.yaml b/talkingface/properties/model/live_speech_portraits/Obama2.yaml new file mode 100644 index 00000000..6d543151 --- /dev/null +++ b/talkingface/properties/model/live_speech_portraits/Obama2.yaml @@ -0,0 +1,32 @@ +model_params: + APC: + ckp_path: './data/APC_epoch_160.model' + mel_dim: 80 + 
hidden_size: 512 + num_layers: 3 + residual: false + use_LLE: 1 + Knear: 10 + LLE_percent: 1 + Audio2Mouth: + ckp_path: './data/Obama2/checkpoints/Audio2Feature.pkl' + smooth: 2 + AMP: ['XYZ', 1.5, 1.5, 1.5] # method, x, y, z + Headpose: + ckp_path: './data/Obama2/checkpoints/Audio2Headpose.pkl' + sigma: 0.3 + smooth: [3, 10] # rot, trans + AMP: [1, 1] # rot, trans + shoulder_AMP: 0.5 + Image2Image: + ckp_path: './data/Obama2/checkpoints/Feature2Face.pkl' + size: 'normal' + save_input: 1 + + +dataset_params: + root: './data/Obama2/' + fit_data_path: './data/Obama2/3d_fit_data.npz' + pts3d_path: './data/Obama2/tracked3D_normalized_pts_fix_contour.npy' + + diff --git a/talkingface/properties/overall.yaml b/talkingface/properties/overall.yaml index 81ac51ae..09391f78 100644 --- a/talkingface/properties/overall.yaml +++ b/talkingface/properties/overall.yaml @@ -20,7 +20,7 @@ stopping_step: 10 # (int) The threshold for validation-based early weight_decay: 0.0 # (float) The weight decay value (L2 penalty) for optimizers. 
saved: True resume: True -train: True +train: False # Evaluation Settings metrics: ["LSE", "SSIM"] diff --git a/talkingface/quick_start/quick_start.py b/talkingface/quick_start/quick_start.py index 3ff2e889..45b32d01 100644 --- a/talkingface/quick_start/quick_start.py +++ b/talkingface/quick_start/quick_start.py @@ -1,5 +1,6 @@ import logging import sys +import torch import torch.distributed as dist from collections.abc import MutableMapping from logging import getLogger @@ -64,6 +65,7 @@ def run_talkingface( config_file_list=config_file_list, config_dict=config_dict, ) + init_seed(config["seed"], config["reproducibility"]) init_logger(config) logger = getLogger() @@ -82,11 +84,9 @@ def run_talkingface( val_data_loader = data_utils.DataLoader( val_dataset, batch_size=config["batch_size"], shuffle=False ) - # load model model = get_model(config["model"])(config).to(config["device"]) logger.info(model) - trainer = get_trainer(config["model"])(config, model) # model training diff --git a/talkingface/trainer/__init__.py b/talkingface/trainer/__init__.py index 819cacbf..45860cff 100644 --- a/talkingface/trainer/__init__.py +++ b/talkingface/trainer/__init__.py @@ -1 +1,2 @@ -from talkingface.trainer.trainer import * \ No newline at end of file +from talkingface.trainer.trainer import * +from talkingface.trainer.live_speech_portraitsTrainer import * \ No newline at end of file diff --git a/talkingface/trainer/live_speech_portraitsTrainer.py b/talkingface/trainer/live_speech_portraitsTrainer.py new file mode 100644 index 00000000..e52b1b13 --- /dev/null +++ b/talkingface/trainer/live_speech_portraitsTrainer.py @@ -0,0 +1,309 @@ +import os +import subprocess +from os.path import join +from tqdm import tqdm +import numpy as np +import torch +from collections import OrderedDict +import librosa +from skimage.io import imread +import cv2 +import scipy.io as sio +import argparse +import yaml +import albumentations +import albumentations.pytorch +from pathlib import Path +from 
.trainer import AbstractTrainer + from talkingface.utils.live_speech_portraits import utils + from talkingface.utils.live_speech_portraits.options.test_audio2feature_options import TestOptions as FeatureOptions + from talkingface.utils.live_speech_portraits.options.test_audio2headpose_options import TestOptions as HeadposeOptions + from talkingface.utils.live_speech_portraits.options.test_feature2face_options import TestOptions as RenderOptions + from talkingface.model.audio_driven_talkingface.LiveSpeechPortraits import create_model + from talkingface.model.audio_driven_talkingface.LiveSpeechPortraits.networks import APC_encoder + from talkingface.utils.live_speech_portraits.util.visualizer import Visualizer + from talkingface.data.dataset.LiveSpeechPortraits import create_dataset + import talkingface.utils.live_speech_portraits.util.util as util + +class live_speech_portraitsTrainer(AbstractTrainer): + def __init__(self, config, model): + self.config = config + self.model = model + + def fit(self, train_data): + r"""Train the model based on the train data.""" + raise NotImplementedError("Method [fit] should be implemented.") + + def evaluate(self, model_file): + r"""Evaluate the model based on the eval data.""" + ############################### I/O Settings ############################## + device = self.config['device'] + optID = self.config['dataset_params']['root'].split('/')[-1] + driving_audio = self.config['driving_audio_path'] + data_root = self.config['dataset_params']['root'] + + # create the results folder + audio_name = driving_audio.split('/')[-1].split('.')[-2] + save_root = join('./results/', optID, audio_name) + if not os.path.exists(save_root): + os.makedirs(save_root) + + + ############################ Hyper Parameters ############################# + h, w, sr, FPS = 512, 512, 16000, 60 + mouth_indices = np.concatenate([np.arange(4, 11), np.arange(46, 64)]) + eye_brow_indices = [27, 65, 28, 68, 29, 67, 30, 66, 31, 72, 32, 69, 33, 70, 34, 71] + 
eye_brow_indices = np.array(eye_brow_indices, np.int32) + + + ############################ Pre-defined Data ############################# + mean_pts3d = np.load(join(data_root, 'mean_pts3d.npy')) + fit_data = np.load(self.config['dataset_params']['fit_data_path']) + pts3d = np.load(self.config['dataset_params']['pts3d_path']) - mean_pts3d + trans = fit_data['trans'][:, :, 0].astype(np.float32) + mean_translation = trans.mean(axis=0) + candidate_eye_brow = pts3d[10:, eye_brow_indices] + std_mean_pts3d = np.load(self.config['dataset_params']['pts3d_path']).mean(axis=0) + + # candidate images + img_candidates = [] + for j in range(4): + output = imread(join(data_root, 'candidates', f'normalized_full_{j}.jpg')) + output = albumentations.pytorch.transforms.ToTensor(normalize={'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5)})(image=output)['image'] + img_candidates.append(output) + img_candidates = torch.cat(img_candidates).unsqueeze(0).to(device) + + # shoulders + shoulders = np.load(join(data_root, 'normalized_shoulder_points.npy')) + shoulder3D = np.load(join(data_root, 'shoulder_points3D.npy'))[1] + ref_trans = trans[1] + + # camera matrix, we always use training set intrinsic parameters. 
+ camera = utils.camera() + camera_intrinsic = np.load(join(data_root, 'camera_intrinsic.npy')).astype(np.float32) + APC_feat_database = np.load(join(data_root, 'APC_feature_base.npy')) + + # load reconstruction data + scale = sio.loadmat(join(data_root, 'id_scale.mat'))['scale'][0, 0] + # Audio2Mel_torch = audio_funcs.Audio2Mel(n_fft=512, hop_length=int(16000/120), win_length=int(16000/60), sampling_rate=16000, + # n_mel_channels=80, mel_fmin=90, mel_fmax=7600.0).to(device) + + + + ########################### Experiment Settings ########################### + #### user config + use_LLE = self.config['model_params']['APC']['use_LLE'] + Knear = self.config['model_params']['APC']['Knear'] + LLE_percent = self.config['model_params']['APC']['LLE_percent'] + headpose_sigma = self.config['model_params']['Headpose']['sigma'] + Feat_smooth_sigma = self.config['model_params']['Audio2Mouth']['smooth'] + Head_smooth_sigma = self.config['model_params']['Headpose']['smooth'] + Feat_center_smooth_sigma, Head_center_smooth_sigma = 0, 0 + AMP_method = self.config['model_params']['Audio2Mouth']['AMP'][0] + Feat_AMPs = self.config['model_params']['Audio2Mouth']['AMP'][1:] + rot_AMP, trans_AMP = self.config['model_params']['Headpose']['AMP'] + shoulder_AMP = self.config['model_params']['Headpose']['shoulder_AMP'] + save_feature_maps = self.config['model_params']['Image2Image']['save_input'] + + #### common settings + Featopt = FeatureOptions().parse() + Headopt = HeadposeOptions().parse() + Renderopt = RenderOptions().parse() + Featopt.load_epoch = self.config['model_params']['Audio2Mouth']['ckp_path'] + Headopt.load_epoch = self.config['model_params']['Headpose']['ckp_path'] + Renderopt.dataroot = self.config['dataset_params']['root'] + Renderopt.load_epoch = self.config['model_params']['Image2Image']['ckp_path'] + Renderopt.size = self.config['model_params']['Image2Image']['size'] + ## GPU or CPU + if device == 'cpu': + Featopt.gpu_ids = Headopt.gpu_ids = Renderopt.gpu_ids = [] + + 
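The CPU branch above clears `gpu_ids` on all three option sets before any model is loaded. That device-string-to-`gpu_ids` convention can be sketched as a small standalone helper (a hypothetical illustration; this function is not part of the repository):

```python
def gpu_ids_for_device(device):
    """Map a config device string to the gpu_ids list convention used by
    the option objects: CPU runs get an empty list, CUDA runs a list of
    integer GPU ids. (Hypothetical helper, for illustration only.)"""
    if device == 'cpu':
        return []
    if device.startswith('cuda'):
        # 'cuda' -> default GPU 0, 'cuda:1' -> GPU 1
        return [int(device.split(':')[1])] if ':' in device else [0]
    raise ValueError('unknown device: ' + device)
```

With this convention, `gpu_ids_for_device('cpu')` yields the empty list assigned to `Featopt`, `Headopt`, and `Renderopt` above.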
+ + ############################# Load Models ################################# + print('---------- Loading Model: APC-------------') + APC_model = APC_encoder(self.config['model_params']['APC']['mel_dim'], + self.config['model_params']['APC']['hidden_size'], + self.config['model_params']['APC']['num_layers'], + self.config['model_params']['APC']['residual']) + APC_model.load_state_dict(torch.load(self.config['model_params']['APC']['ckp_path']), strict=False) + if device == 'cuda': + APC_model.cuda() + APC_model.eval() + print('---------- Loading Model: {} -------------'.format(Featopt.task)) + Audio2Feature = create_model(Featopt) + Audio2Feature.setup(Featopt) + Audio2Feature.eval() + print('---------- Loading Model: {} -------------'.format(Headopt.task)) + Audio2Headpose = create_model(Headopt) + Audio2Headpose.setup(Headopt) + Audio2Headpose.eval() + if Headopt.feature_decoder == 'WaveNet': + if device == 'cuda': + Headopt.A2H_receptive_field = Audio2Headpose.Audio2Headpose.module.WaveNet.receptive_field + else: + Headopt.A2H_receptive_field = Audio2Headpose.Audio2Headpose.WaveNet.receptive_field + print('---------- Loading Model: {} -------------'.format(Renderopt.task)) + facedataset = create_dataset(Renderopt) + Feature2Face = create_model(Renderopt) + Feature2Face.setup(Renderopt) + Feature2Face.eval() + visualizer = Visualizer(Renderopt) + + + ############################## Inference ################################## + print('Processing audio: {} ...'.format(audio_name)) + # read audio + audio, _ = librosa.load(driving_audio, sr=sr) + total_frames = np.int32(audio.shape[0] / sr * FPS) + + + #### 1. compute APC features + print('1. 
Computing APC features...') + mel80 = utils.compute_mel_one_sequence(audio, device=device) + mel_nframe = mel80.shape[0] + with torch.no_grad(): + length = torch.Tensor([mel_nframe]) + mel80_torch = torch.from_numpy(mel80.astype(np.float32)).to(device).unsqueeze(0) + hidden_reps = APC_model.forward(mel80_torch, length)[0] # [mel_nframe, 512] + hidden_reps = hidden_reps.cpu().numpy() + audio_feats = hidden_reps + + + #### 2. manifold projection + if use_LLE: + print('2. Manifold projection...') + ind = utils.KNN_with_torch(audio_feats, APC_feat_database, K=Knear) + weights, feat_fuse = utils.compute_LLE_projection_all_frame(audio_feats, APC_feat_database, ind, + audio_feats.shape[0]) + audio_feats = audio_feats * (1 - LLE_percent) + feat_fuse * LLE_percent + + + + #### 3. Audio2Mouth + print('3. Audio2Mouth inference...') + pred_Feat = Audio2Feature.generate_sequences(audio_feats, sr, FPS, fill_zero=True, opt=Featopt) + + + + + #### 4. Audio2Headpose + print('4. Headpose inference...') + # set history headposes as zero + pre_headpose = np.zeros(Headopt.A2H_wavenet_input_channels, np.float32) + pred_Head = Audio2Headpose.generate_sequences(audio_feats, pre_headpose, fill_zero=True, sigma_scale=0.3, + opt=Headopt) + + + + #### 5. Post-Processing + print('5. 
Post-processing...') + nframe = min(pred_Feat.shape[0], pred_Head.shape[0]) + pred_pts3d = np.zeros([nframe, 73, 3]) + pred_pts3d[:, mouth_indices] = pred_Feat.reshape(-1, 25, 3)[:nframe] + + + + ## mouth + pred_pts3d = utils.landmark_smooth_3d(pred_pts3d, Feat_smooth_sigma, area='only_mouth') + pred_pts3d = utils.mouth_pts_AMP(pred_pts3d, True, AMP_method, Feat_AMPs) + pred_pts3d = pred_pts3d + mean_pts3d + pred_pts3d = utils.solve_intersect_mouth(pred_pts3d) # resolve intersecting lips if present + + + + ## headpose + pred_Head[:, 0:3] *= rot_AMP + pred_Head[:, 3:6] *= trans_AMP + pred_headpose = utils.headpose_smooth(pred_Head[:, :6], Head_smooth_sigma).astype(np.float32) + pred_headpose[:, 3:] += mean_translation + pred_headpose[:, 0] += 180 + + + + ## compute projected landmarks + pred_landmarks = np.zeros([nframe, 73, 2], dtype=np.float32) + final_pts3d = np.zeros([nframe, 73, 3], dtype=np.float32) + final_pts3d[:] = std_mean_pts3d.copy() + final_pts3d[:, 46:64] = pred_pts3d[:nframe, 46:64] + for k in tqdm(range(nframe)): + ind = k % candidate_eye_brow.shape[0] + final_pts3d[k, eye_brow_indices] = candidate_eye_brow[ind] + mean_pts3d[eye_brow_indices] + pred_landmarks[k], _, _ = utils.project_landmarks(camera_intrinsic, camera.relative_rotation, + camera.relative_translation, scale, + pred_headpose[k], final_pts3d[k]) + + + ## Upper Body Motion + pred_shoulders = np.zeros([nframe, 18, 2], dtype=np.float32) + pred_shoulders3D = np.zeros([nframe, 18, 3], dtype=np.float32) + for k in range(nframe): + diff_trans = pred_headpose[k][3:] - ref_trans + pred_shoulders3D[k] = shoulder3D + diff_trans * shoulder_AMP + # project + project = camera_intrinsic.dot(pred_shoulders3D[k].T) + project[:2, :] /= project[2, :] # divide z + pred_shoulders[k] = project[:2, :].T + + + + #### 6. Image2Image translation & Save results + print('6. 
Image2Image translation & Saving results...') + for ind in tqdm(range(0, nframe), desc='Image2Image translation inference'): + # feature_map: [input_nc, h, w] + current_pred_feature_map = facedataset.dataset.get_data_test_mode(pred_landmarks[ind], + pred_shoulders[ind], + facedataset.dataset.image_pad) + input_feature_maps = current_pred_feature_map.unsqueeze(0).to(device) + pred_fake = Feature2Face.inference(input_feature_maps, img_candidates) + # save results + visual_list = [('pred', util.tensor2im(pred_fake[0]))] + if save_feature_maps: + visual_list += [('input', np.uint8(current_pred_feature_map[0].cpu().numpy() * 255))] + visuals = OrderedDict(visual_list) + visualizer.save_images(save_root, visuals, str(ind + 1)) + + ## make videos + # generate corresponding audio, reused for all results + tmp_audio_path = join(save_root, 'tmp.wav') + tmp_audio_clip = audio[: np.int32(nframe * sr / FPS)] + librosa.output.write_wav(tmp_audio_path, tmp_audio_clip, sr) + + final_path = join(save_root, audio_name + '.avi') +# self.write_video_with_audio(tmp_audio_path, final_path, 'pred_') + fps, fourcc = 60, cv2.VideoWriter_fourcc(*'DIVX') + video_tmp_path = join(save_root, 'tmp.avi') + out = cv2.VideoWriter(video_tmp_path, fourcc, fps, (Renderopt.loadSize, Renderopt.loadSize)) + for j in tqdm(range(nframe), position=0, desc='writing video'): + img = cv2.imread(join(save_root, 'pred_' + str(j + 1) + '.jpg')) + out.write(img) + out.release() + cmd = 'ffmpeg -i "' + video_tmp_path + '" -i "' + tmp_audio_path + '" -codec copy -shortest "' + final_path + '"' + subprocess.call(cmd, shell=True) + os.remove(video_tmp_path) # remove the temporary video + + feature_maps_path = join(save_root, audio_name + '_feature_maps.avi') +# self.write_video_with_audio(tmp_audio_path, feature_maps_path, 'input_') + fps, fourcc = 60, cv2.VideoWriter_fourcc(*'DIVX') + video_tmp_path = join(save_root, 'tmp.avi') + out = cv2.VideoWriter(video_tmp_path, fourcc, fps, (Renderopt.loadSize, 
Renderopt.loadSize)) + for j in tqdm(range(nframe), position=0, desc='writing video'): + img = cv2.imread(join(save_root, 'input_' + str(j + 1) + '.jpg')) + out.write(img) + out.release() + cmd = 'ffmpeg -i "' + video_tmp_path + '" -i "' + tmp_audio_path + '" -codec copy -shortest "' + feature_maps_path + '"' + subprocess.call(cmd, shell=True) + os.remove(video_tmp_path) # remove the temporary video + + + if os.path.exists(tmp_audio_path): + os.remove(tmp_audio_path) + if not self.config['save_intermediates']: + _img_paths = list(map(lambda x: str(x), list(Path(save_root).glob('*.jpg')))) + for i in tqdm(range(len(_img_paths)), desc='deleting intermediate images'): + os.remove(_img_paths[i]) + + print('Finish!') + # raise NotImplementedError("Method [next] should be implemented.") + diff --git a/talkingface/trainer/trainer.py b/talkingface/trainer/trainer.py index 2c34717b..026900f0 100644 --- a/talkingface/trainer/trainer.py +++ b/talkingface/trainer/trainer.py @@ -2,7 +2,7 @@ from logging import getLogger from time import time -import dlib, json, subprocess +# import dlib, json, subprocess import torch.nn.functional as F import glob import numpy as np diff --git a/talkingface/utils/live_speech_portraits/audio_funcs.py b/talkingface/utils/live_speech_portraits/audio_funcs.py new file mode 100644 index 00000000..8ba875c4 --- /dev/null +++ b/talkingface/utils/live_speech_portraits/audio_funcs.py @@ -0,0 +1,269 @@ +import os +import os.path +import math +# import sox +# import pyworld as pw +import torch +import torch.utils.data +import numpy as np +import librosa + +""" +usage +fft = Audio2Mel().cuda() +# audio shape is B x 1 x T, the normalized mel shape is B x D x T +mel = fft(audio) +""" +from librosa.filters import mel as librosa_mel_fn +import torch.nn.functional as F + + +class Audio2Mel(torch.nn.Module): + def __init__( + self, + n_fft=512, + hop_length=256, + win_length=1024, + sampling_rate=16000, + n_mel_channels=80, + mel_fmin=90, + mel_fmax=7600.0, + ): 
+ super(Audio2Mel, self).__init__() + ############################################## + # FFT Parameters # + ############################################## + window = torch.hann_window(win_length).float() + mel_basis = librosa_mel_fn( + sampling_rate, n_fft, n_mel_channels, mel_fmin, mel_fmax + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + self.register_buffer("window", window) + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.sampling_rate = sampling_rate + self.n_mel_channels = n_mel_channels + self.min_mel = math.log(1e-5) + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + + """ + input audio signal (-1,1): B x 1 x T + output mel signal: B x D x T', T' is a reduction of T + """ + + def forward(self, audio, normalize=True): + p = (self.n_fft - self.hop_length) // 2 + audio = F.pad(audio, (p, p), "reflect").squeeze(1) + fft = torch.stft( + audio, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=False, + return_complex=False + ) + real_part, imag_part = fft.unbind(-1) + magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2) + mel_output = torch.matmul(self.mel_basis, magnitude) + log_mel_spec = torch.log(torch.clamp(mel_output, min=1e-5)) + + # normalize to the range [0,1] + if normalize: + log_mel_spec = (log_mel_spec - self.min_mel) / -self.min_mel + return log_mel_spec + + def mel_to_audio(self, mel): + mel = torch.exp(mel * (-self.min_mel) + self.min_mel) ** 2 + mel_np = mel.cpu().numpy() + audio = librosa.feature.inverse.mel_to_audio(mel_np, sr=self.sampling_rate, n_fft=self.n_fft, + hop_length=self.hop_length, win_length=self.win_length, + window='hann', center=False, + pad_mode='reflect', power=2.0, n_iter=32, fmin=self.mel_fmin, + fmax=self.mel_fmax) + return audio + + """ + here we will get per frame energy to replace mc0 in the corresponding prosody representation + the audio is already in the 
GPU to accelerate computation + input audio signal: B x 1 x T + output energy: B x 1 x T' + """ + + def get_energy(self, audio, normalize=True): + # B x 1 x T + p = (self.n_fft - self.hop_length) // 2 + audio_new = F.pad(audio, (p, p), "reflect").squeeze(1) + # audio_new = audio.squeeze(1) + audio_fold = audio_new.unfold(1, self.win_length, self.hop_length) + audio_energy = torch.sqrt(torch.mean(audio_fold ** 2, dim=-1)) + audio_energy = torch.log(torch.clamp(audio_energy, min=1e-5)) + if normalize: + audio_energy = (audio_energy - self.min_mel) / -self.min_mel + return audio_energy + + # we can get the energy of mels here, B*D*T + def get_energy_mel(self, mels, normalize=True): + m = mels.exp().mean(dim=1) + audio_energy = torch.log(m) + # audio_energy = torch.log(torch.clamp(m,min=1e-5)) + # if normalize: + # audio_energy = (audio_energy - self.min_mel) / -self.min_mel + return audio_energy + + +def mu_law_encoding(data, mu=255): + '''encode the original audio via mu-law companding and mu-bits quantization + ''' + # mu-law companding + mu_x = np.sign(data) * np.log(1 + mu * np.abs(data)) / np.log(mu + 1) + # mu-bits quantization from [-1, 1] to [0, mu] + mu_x = (mu_x + 1) / 2 * mu + 0.5 + return mu_x.astype(np.int32) + + +# %timeit mu_x = mu_law_encoding(x, 255) 305 µs ± 554 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each) + + +def mu_law_decoding(data, mu=255): + '''invert the mu-law compressed and quantized data. + ''' + # dequantization + y = 2 * (data.astype(np.float32) / mu) - 1 + # inverse mu-law companding + x = np.sign(y) * (1.0 / mu) * ((1.0 + mu) ** abs(y) - 1.0) + return x + + +## audio augmentation +def inject_gaussian_noise(data, noise_factor, use_torch=False): + ''' inject random gaussian noise (mean=0, std=1) to audio clip + In my tests, a reasonable factor range is [0, 0.01]; + larger values add too much noise and smaller ones have negligible effect. 
+ Args: + data: [n,] original audio sequence + noise_factor(float): scaled factor + use_torch(bool): optional, if use_torch=True, input data and implementation will + be torch methods. + Returns: + augmented_data: [n,] noised audio clip + + ''' + if use_torch == False: + augmented_data = data + noise_factor * np.random.normal(0, 1, len(data)) + # Cast back to same data type + augmented_data = augmented_data.astype(type(data[0])) + # use torch + else: + augmented_data = data + noise_factor * torch.randn(1).cuda() + + return augmented_data + + +# pitch shifting +def pitch_shifting(data, sampling_rate=48000, factor=5): + ''' shift the audio pitch. + ''' + # Permissible factor values = -5 <= x <= 5 + pitch_factor = np.random.rand(1) * 2 * factor - factor + return librosa.effects.pitch_shift(data, sampling_rate, pitch_factor) + + +def speed_change(data, landmark=None): + ''' change the speed of input audio. Note that we return the speed_rate to + change the speed of landmarks or videos. + Args: + data: [n,] audio clip + landmark: [m, pts, 2] aligned landmarks with audio if existed. 
+ ''' + # Permissible factor values = 0.7 <= x <= 1.3 (higher is faster) + # resulting audio length: np.round(n/rate) + speed_rate = np.random.uniform(0.7, 1.3) + # only augment audio + if landmark is None: + return librosa.effects.time_stretch(data, speed_rate), speed_rate + else: + # n_after = np.round(data.shape[0]/speed_rate) + pass + + +def prepare_noises(scp_file, root=None, sampling_rate=None, ignore_class=None): + noises = [] + print('Loading augmentation noises...') + with open(scp_file, 'r') as fp: + for line in fp.readlines(): + line = line.rstrip('\n') + if ignore_class is not None and ignore_class in line: + continue + + noise, sr = librosa.load(os.path.join(root, line), sr=sampling_rate) + noises.append(noise) + print('Augmentation noises loaded!') + return noises, sr + + +def add_gauss_noise(wav, noise_std=0.03, max_wav_value=1.0): + if isinstance(wav, np.ndarray): + wav = torch.tensor(wav.copy()) + + real_std = np.random.random() * noise_std + wav_new = wav.float() / max_wav_value + torch.randn(wav.size()) * real_std + wav_new = wav_new * max_wav_value + wav_new = wav_new.clamp_(-max_wav_value, max_wav_value) + + return wav_new.float().numpy() + + +def add_background_noise(wav, noises, min_snr=2, max_snr=15): + def mix_noise(wav, noise, scale): + x = wav + scale * noise + x = x.clip(-1, 1) + return x + + def voice_energy(wav): + wav_float = np.copy(wav) + return np.sum(wav_float ** 2) / (wav_float.shape[0] + 1e-5) + + def voice_energy_ratio(wav, noise, target_snr): + wav_eng = voice_energy(wav) + noise_eng = voice_energy(noise) + target_noise_eng = wav_eng / (10 ** (target_snr / 10.0)) + ratio = target_noise_eng / (noise_eng + 1e-5) + return ratio + + total_id = len(noises) + # 0 is no need to generate the noise + idx = np.random.choice(range(0, total_id)) + noise_wav = noises[idx] + if noise_wav.shape[0] > wav.shape[0]: + sel_range_id = np.random.choice(range(0, noise_wav.shape[0] - wav.shape[0])) + n = noise_wav[sel_range_id:sel_range_id + 
wav.shape[0]] + else: + n = np.zeros(wav.shape[0]) + sel_range_id = np.random.choice(range(0, wav.shape[0] - noise_wav.shape[0] + 1)) + n[sel_range_id:sel_range_id + noise_wav.shape[0]] = noise_wav + # + target_snr = np.random.random() * (max_snr - min_snr) + min_snr + scale = voice_energy_ratio(wav, n, target_snr) + wav_new = mix_noise(wav, n, scale) + return wav_new + + +def noise_augment(wav, wav_noises, gaussian_prob=0.5): + if np.random.random() > gaussian_prob: # add gauss noise + noise_std = np.random.uniform(low=0.001, high=0.02) + aug_wave_data = add_gauss_noise(wav, noise_std=noise_std) + else: # add background noise + aug_wave_data = add_background_noise(wav, wav_noises, min_snr=2, max_snr=15) + + return aug_wave_data + + + + + + + diff --git a/talkingface/utils/live_speech_portraits/options/__init__.py b/talkingface/utils/live_speech_portraits/options/__init__.py new file mode 100644 index 00000000..e7eedebe --- /dev/null +++ b/talkingface/utils/live_speech_portraits/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/talkingface/utils/live_speech_portraits/options/base_options_audio2feature.py b/talkingface/utils/live_speech_portraits/options/base_options_audio2feature.py new file mode 100644 index 00000000..504ed345 --- /dev/null +++ b/talkingface/utils/live_speech_portraits/options/base_options_audio2feature.py @@ -0,0 +1,188 @@ +import argparse +import os +import torch +import numpy as np +from talkingface.utils.live_speech_portraits.util import util +import talkingface.model.audio_driven_talkingface.LiveSpeechPortraits as models + + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. 
+ """ + + def __init__(self): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + ## task + parser.add_argument('--task', type=str, default='Audio2Feature', help='|Audio2Feature|Feature2Face|etc.') + + ## useless parameters 新加的参数 + parser.add_argument('--dataset', type=str, default='default', help='useless, just for dataset=live_speech_portraits') + parser.add_argument('--model', type=str, default='default', help='useless, just for model=live_speech_portraits') + + ## basic parameters + parser.add_argument('--model_name', type=str, default='audio2feature', help='trained model') + parser.add_argument('--dataset_mode', type=str, default='audiovisual', help='chooses how datasets are loaded. [unaligned | aligned | single]') + parser.add_argument('--name', type=str, default='Audio2Feature', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/', help='models are saved here') + + + # dataset parameters + parser.add_argument('--dataset_names', type=str, default='default_name') + parser.add_argument('--dataroot', type=str, default='default_path') + parser.add_argument('--frame_jump_stride', type=int, default=4, help='jump index in audio dataset.') + parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data') + parser.add_argument('--batch_size', type=int, default=32, help='input batch size') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--audio_encoder', type=str, default='APC', help='|CNN|LSTM|APC|NPC|') + parser.add_argument('--feature_decoder', type=str, default='LSTM', help='|WaveNet|LSTM|') + parser.add_argument('--loss', type=str, default='L2', help='|GMM|L2|') + parser.add_argument('--A2L_GMM_ndim', type=int, default=25*3) + parser.add_argument('--sequence_length', type=int, default=240, help='length of training frames in each iteration') + + + # data setting parameters + parser.add_argument('--FPS', type=str, default=60, help='video fps') + parser.add_argument('--sample_rate', type=int, default=16000, help='audio sample rate') + parser.add_argument('--audioRF_history', type=int, default=60, help='audio history receptive field length') + parser.add_argument('--audioRF_future', type=int, default=0, help='audio future receptive field length') + parser.add_argument('--feature_dtype', type=str, default='pts3d', help='|FW|pts3d|') + parser.add_argument('--ispts_norm', type=int, default=1, help='use normalized 3d points.') + parser.add_argument('--use_delta_pts', type=int, default=1, help='whether use delta landmark representation') + parser.add_argument('--frame_future', type=int, default=18) + parser.add_argument('--predict_length', type=int, default=1) + parser.add_argument('--only_mouth', type=int, default=1) + + + # APC parameters + parser.add_argument('--APC_hidden_size', type=int, default=512) + parser.add_argument('--APC_rnn_layers', type=int, default=3) + parser.add_argument("--APC_residual", action="store_true") + parser.add_argument('--APC_frame_history', type=int, default=0) + + + # LSTM parameters + parser.add_argument('--LSTM_hidden_size', type=int, default=256) + parser.add_argument('--LSTM_output_size', type=int, default=80) + parser.add_argument('--LSTM_layers', type=int, default=3) + parser.add_argument('--LSTM_dropout', type=float, default=0) + 
parser.add_argument("--LSTM_residual", action="store_true") + parser.add_argument('--LSTM_sequence_length', type=int, default=60) + + + # additional parameters + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + + self.initialized = True + return parser + + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + print('opt:', opt) + + # modify model-related parser options + model_name = opt.model_name + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + # save and return the parser + self.parser = parser + return opt + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). 
+ It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + if opt.isTrain: + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + # if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(opt.gpu_ids[0]) + + # set datasets + if self.isTrain: + opt.train_dataset_names = np.loadtxt(os.path.join(opt.dataroot, + opt.dataset_names, + opt.train_dataset_names), dtype=np.str).tolist() + if type(opt.train_dataset_names) == str: + opt.train_dataset_names = [opt.train_dataset_names] + opt.validate_dataset_names = np.loadtxt(os.path.join(opt.dataroot, + opt.dataset_names, + opt.validate_dataset_names), dtype=np.str).tolist() + if type(opt.validate_dataset_names) == str: + opt.validate_dataset_names = [opt.validate_dataset_names] + + self.opt = opt + return self.opt + + + + + + + + + + + diff --git a/talkingface/utils/live_speech_portraits/options/base_options_audio2headpose.py 
b/talkingface/utils/live_speech_portraits/options/base_options_audio2headpose.py new file mode 100644 index 00000000..860ae5cf --- /dev/null +++ b/talkingface/utils/live_speech_portraits/options/base_options_audio2headpose.py @@ -0,0 +1,192 @@ +import argparse +import os +import torch +from talkingface.utils.live_speech_portraits.util import util +import talkingface.model.audio_driven_talkingface.LiveSpeechPortraits as models +import numpy as np + + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. + """ + + def __init__(self): + """Reset the class; indicates the class hasn't been initialized""" + self.initialized = False + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + ## task + parser.add_argument('--task', type=str, default='Audio2Headpose', help='|Audio2Feature|Feature2Face|Full|') + + ## unused parameters (newly added) + parser.add_argument('--dataset', type=str, default='default', help='unused, just for dataset=live_speech_portraits') + parser.add_argument('--model', type=str, default='default', help='unused, just for model=live_speech_portraits') + + ## basic parameters + parser.add_argument('--model_name', type=str, default='audio2headpose', help='trained model') + parser.add_argument('--dataset_mode', type=str, default='audiovisual', help='chooses how datasets are loaded. [unaligned | aligned | single]') + parser.add_argument('--name', type=str, default='Audio2Headpose', help='name of the experiment. It decides where to store samples and models') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/', help='models are saved here') + + + # data parameters + parser.add_argument('--FPS', type=str, default=60, help='video fps') + parser.add_argument('--sample_rate', type=int, default=16000, help='audio sample rate') + parser.add_argument('--audioRF_history', type=int, default=60, help='audio history receptive field length') + parser.add_argument('--audioRF_future', type=int, default=0, help='audio future receptive field length') + parser.add_argument('--feature_decoder', type=str, default='WaveNet', help='|WaveNet|LSTM|') + parser.add_argument('--loss', type=str, default='GMM', help='|GMM|L2|') + + + # dataset parameters + parser.add_argument('--dataset_names', type=str, default='name', help='chooses how datasets are loaded.') + parser.add_argument('--dataroot', type=str, default='path') + parser.add_argument('--frame_jump_stride', type=int, default=1, help='jump index in audio dataset.') + parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data') + parser.add_argument('--batch_size', type=int, default=32, help='input batch size') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--audio_encoder', type=str, default='APC', help='|CNN|LSTM|APC|') + parser.add_argument('--audiofeature_input_channels', type=int, default=80, help='input channels of audio features') + parser.add_argument('--frame_future', type=int, default=15) + parser.add_argument('--predict_length', type=int, default=5) + parser.add_argument('--audio_windows', type=int, default=2) + parser.add_argument('--time_frame_length', type=int, default=240, help='length of training frames in each iteration') + + + # APC parameters + parser.add_argument('--APC_hidden_size', type=int, default=512) + parser.add_argument('--APC_rnn_layers', type=int, default=3) + parser.add_argument("--APC_residual", action="store_true") + parser.add_argument('--APC_frame_history', type=int, default=60) + + + ## network parameters + # audio2headpose wavenet + parser.add_argument('--A2H_wavenet_residual_layers', type=int, default=7, help='residual layer numbers') + parser.add_argument('--A2H_wavenet_residual_blocks', type=int, default=2, help='residual block numbers') + parser.add_argument('--A2H_wavenet_dilation_channels', type=int, default=128, help='dilation convolution channels') + parser.add_argument('--A2H_wavenet_residual_channels', type=int, default=128, help='residual channels') + parser.add_argument('--A2H_wavenet_skip_channels', type=int, default=256, help='skip channels') + parser.add_argument('--A2H_wavenet_kernel_size', type=int, default=2, help='dilation convolution kernel size') + parser.add_argument('--A2H_wavenet_use_bias', type=bool, default=True, help='whether to use bias in dilation convolution') + parser.add_argument('--A2H_wavenet_cond', type=bool, default=True, help='whether use condition input') + parser.add_argument('--A2H_wavenet_cond_channels', type=int, default=512, help='whether use condition input') + parser.add_argument('--A2H_wavenet_input_channels', type=int, 
default=12, help='input channels') + parser.add_argument('--A2H_GMM_ncenter', type=int, default=1, help='gaussian distribution numbers, 1 for single gaussian distribution') + parser.add_argument('--A2H_GMM_ndim', type=int, default=12, help='dimension of each gaussian, usually number of pts') + parser.add_argument('--A2H_GMM_sigma_min', type=float, default=0.03, help='minimal gaussian sigma values') + + + # additional parameters + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + parser.add_argument('--sequence_length', type=int, default=240, help='length of training frames in each iteration') + + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model_name + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + + # save and return the parser + self.parser = parser + return opt + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). 
+ It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + if opt.isTrain: + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + # if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(opt.gpu_ids[0]) + + # set datasets + if self.isTrain: + opt.train_dataset_names = np.loadtxt(os.path.join(opt.dataroot, + opt.dataset_names, + opt.train_dataset_names), dtype=str).tolist() + if type(opt.train_dataset_names) == str: + opt.train_dataset_names = [opt.train_dataset_names] + opt.validate_dataset_names = np.loadtxt(os.path.join(opt.dataroot, + opt.dataset_names, + opt.validate_dataset_names), dtype=str).tolist() + if type(opt.validate_dataset_names) == str: + opt.validate_dataset_names = [opt.validate_dataset_names] + + self.opt = opt + return self.opt + + + + + + + + + + + diff --git a/talkingface/utils/live_speech_portraits/options/base_options_feature2face.py
b/talkingface/utils/live_speech_portraits/options/base_options_feature2face.py new file mode 100644 index 00000000..27ff7125 --- /dev/null +++ b/talkingface/utils/live_speech_portraits/options/base_options_feature2face.py @@ -0,0 +1,132 @@ +import argparse +import os +from talkingface.utils.live_speech_portraits.util import util +import torch +import numpy as np + +class BaseOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser() + self.initialized = False + + def initialize(self): + ## task + self.parser.add_argument('--task', type=str, default='Feature2Face', help='|Audio2Feature|Feature2Face|Full|') + self.parser.add_argument('--model_name', type=str, default='feature2face', help='chooses which model to use. vid2vid, test') + self.parser.add_argument('--name', type=str, default='TestRender', help='name of the experiment. It decides where to store samples and models') + self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints/', help='models are saved here') + self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + + ## unused parameters (newly added) + self.parser.add_argument('--dataset', type=str, default='default', help='useless, just for dataset=live_speech_portraits') + self.parser.add_argument('--model', type=str, default='default', help='useless, just for model=live_speech_portraits') + + # display + self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size') + self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display') + self.parser.add_argument('--tf_log', default=True, action='store_true', help='if specified, use tensorboard logging.
Requires tensorflow installed') + + + # input/output size + self.parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + self.parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size') + self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') + self.parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels') + self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') + + + # setting inputs + self.parser.add_argument('--dataset_mode', type=str, default='face', help='chooses how datasets are loaded.') + self.parser.add_argument('--dataroot', type=str, default='./data/') + self.parser.add_argument('--isH5', type=int, default=1, help='whether to use h5py to save dataset') + self.parser.add_argument('--suffix', type=str, default='.jpg', help='image suffix') + self.parser.add_argument('--isMask', type=int, default=0, help='use face mask') + self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + self.parser.add_argument('--num_threads', default=0, type=int, help='# threads for loading data') + self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + self.parser.add_argument('--resize_or_crop', type=str, default='scaleWidth', help='scaling and cropping of images at load time [resize_and_crop|crop|scaledCrop|scaleWidth|scaleWidth_and_crop|scaleWidth_and_scaledCrop|scaleHeight|scaleHeight_and_crop] etc') + self.parser.add_argument('--no_flip', type=int, default=1, help='if specified, do not flip the images for data augmentation') + + + # generator arch + self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') + self.parser.add_argument('--n_downsample_G', type=int, default=8, help='number of downsampling layers in netG') + self.parser.add_argument('--ngf_E', type=int, default=16, help='# of gen filters in first conv layer') + self.parser.add_argument('--n_downsample_E', type=int, default=3, help='number of downsampling layers in Enhancement') + self.parser.add_argument('--n_blocks_E', type=int, default=3, help='number of resnet blocks in Enhancement') + + + # miscellaneous + self.parser.add_argument('--load_pretrain', type=str, default='', help='if specified, load the pretrained model') + self.parser.add_argument('--debug', action='store_true', help='if specified, use small dataset for debug') + self.parser.add_argument('--fp16', type=int, default=0, help='train with AMP') + self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training') + self.parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + + self.initialized = True + + def parse_str(self, ids): + str_ids = ids.split(',') + ids_list = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + ids_list.append(id) + return ids_list + + def parse(self, save=True): + if not self.initialized: + self.initialize() + self.opt, _ = self.parser.parse_known_args() + self.opt.isTrain = self.isTrain # train or test + + 
self.opt.gpu_ids = self.parse_str(self.opt.gpu_ids) + + # set gpu ids + # if len(self.opt.gpu_ids) > 0: + # torch.cuda.set_device(self.opt.gpu_ids[0]) + + # set datasets + datasets = self.opt.dataset_names.split(',') + self.opt.dataset_names = [] + for name in datasets: + self.opt.dataset_names.append(name) + + if self.isTrain: + self.opt.train_dataset_names = np.loadtxt(os.path.join(self.opt.dataroot, + self.opt.dataset_names[0], + self.opt.train_dataset_names), dtype=str).tolist() + if type(self.opt.train_dataset_names) == str: + self.opt.train_dataset_names = [self.opt.train_dataset_names] + self.opt.validate_dataset_names = np.loadtxt(os.path.join(self.opt.dataroot, + self.opt.dataset_names[0], + self.opt.validate_dataset_names), dtype=str).tolist() + if type(self.opt.validate_dataset_names) == str: + self.opt.validate_dataset_names = [self.opt.validate_dataset_names] + + else: + test_datasets = self.opt.test_dataset_names.split(',') + self.opt.test_dataset_names = [] + for name in test_datasets: + self.opt.test_dataset_names.append(name) + + + args = vars(self.opt) + + print('------------ Options -------------') + for k, v in sorted(args.items()): + print('%s: %s' % (str(k), str(v))) + print('-------------- End ----------------') + + if self.isTrain: + # save to the disk + expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) + util.mkdirs(expr_dir) + if save: + file_name = os.path.join(expr_dir, 'opt.txt') + with open(file_name, 'wt') as opt_file: + opt_file.write('------------ Options -------------\n') + for k, v in sorted(args.items()): + opt_file.write('%s: %s\n' % (str(k), str(v))) + opt_file.write('-------------- End ----------------\n') + return self.opt diff --git a/talkingface/utils/live_speech_portraits/options/test_audio2feature_options.py b/talkingface/utils/live_speech_portraits/options/test_audio2feature_options.py new file mode 100644 index 00000000..2478db2d --- /dev/null +++ 
b/talkingface/utils/live_speech_portraits/options/test_audio2feature_options.py @@ -0,0 +1,20 @@ +from .base_options_audio2feature import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--load_epoch', type=str, default='500', help='which epoch to load? set to latest to use latest cached model') + # Dropout and BatchNorm have different behaviour during training and test. + parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + # rewrite default values + parser.set_defaults(time_frame_length=1) + self.isTrain = False + + return parser diff --git a/talkingface/utils/live_speech_portraits/options/test_audio2headpose_options.py b/talkingface/utils/live_speech_portraits/options/test_audio2headpose_options.py new file mode 100644 index 00000000..6e32746f --- /dev/null +++ b/talkingface/utils/live_speech_portraits/options/test_audio2headpose_options.py @@ -0,0 +1,20 @@ +from .base_options_audio2headpose import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + parser.add_argument('--load_epoch', type=str, default='500', help='which epoch to load? set to latest to use latest cached model') + # Dropout and BatchNorm have different behaviour during training and test.
+ parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + # rewrite default values + parser.set_defaults(time_frame_length=1) + self.isTrain = False + + return parser \ No newline at end of file diff --git a/talkingface/utils/live_speech_portraits/options/test_feature2face_options.py b/talkingface/utils/live_speech_portraits/options/test_feature2face_options.py new file mode 100644 index 00000000..63e2ea6c --- /dev/null +++ b/talkingface/utils/live_speech_portraits/options/test_feature2face_options.py @@ -0,0 +1,11 @@ +from .base_options_feature2face import BaseOptions + + +class TestOptions(BaseOptions): + def initialize(self): + BaseOptions.initialize(self) + self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + self.parser.add_argument('--dataset_names', type=str, default='name', help='chooses test datasets.') + self.parser.add_argument('--test_dataset_names', type=str, default='name', help='chooses test datasets.') + + self.isTrain = False diff --git a/talkingface/utils/live_speech_portraits/options/train_audio2feature_options.py b/talkingface/utils/live_speech_portraits/options/train_audio2feature_options.py new file mode 100644 index 00000000..0491e38f --- /dev/null +++ b/talkingface/utils/live_speech_portraits/options/train_audio2feature_options.py @@ -0,0 +1,56 @@ +from .base_options_audio2feature import BaseOptions + + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. 
+ """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + + + # network saving and loading parameters + parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', default=False, action='store_true', help='continue training: load the latest model') + parser.add_argument('--load_epoch', type=str, default='200', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--epoch_count', type=int, default=0, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--re_transform', type=int, default=0, help='re-transform landmarks') + + + # training parameters + parser.add_argument('--train_dataset_names', type=str, default='train_list.txt', help='chooses training datasets.') + parser.add_argument('--validate_dataset_names', type=str, default='val_list.txt', help='chooses validation datasets.') + parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | step | plateau | cosine]') + parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate for adam') + parser.add_argument('--gamma', type=float, default=0.2, help='step learning rate gamma') + parser.add_argument('--lr_decay_iters', type=int, default=250, help='multiply by a gamma every lr_decay_iters iterations') + parser.add_argument('--n_epochs_decay', type=int, default=250, help='number of epochs to linearly decay learning rate to zero') + parser.add_argument('--validate_epoch', type=int, default=50, help='validate model every some epochs, 0 for not validate during training') + parser.add_argument('--loss_smooth_weight', type=float, default=0, help='smooth loss weight, 0 for not use smooth loss') + parser.add_argument('--optimizer', type=str, default='AdamW', help='Adam, AdamW, RMSprop') + + + # data augmentations + parser.add_argument('--gaussian_noise', type=int, default=1, help='whether add gaussian noise to input & groundtruth features') + parser.add_argument('--gaussian_noise_scale', type=float, default=0.01, help='gaussian noise scale') + + + self.isTrain = True + return parser + + + + + + + + + + + + diff --git a/talkingface/utils/live_speech_portraits/options/train_audio2headpose_options.py b/talkingface/utils/live_speech_portraits/options/train_audio2headpose_options.py new file mode 100644 index 00000000..27339c81 --- /dev/null +++ b/talkingface/utils/live_speech_portraits/options/train_audio2headpose_options.py @@ -0,0 +1,45 @@ +from .base_options_audio2headpose import BaseOptions + + +class TrainOptions(BaseOptions): + """This class includes training options. + + It also includes shared options defined in BaseOptions. 
+ """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) + + + # network saving and loading parameters + parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') + parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') + parser.add_argument('--continue_train', default=False, action='store_true', help='continue training: load the latest model') + parser.add_argument('--load_epoch', type=str, default='0', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--epoch_count', type=int, default=0, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') + parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + parser.add_argument('--re_transform', type=int, default=0, help='re-transform landmarks') + + + # training parameters + parser.add_argument('--smooth_loss', type=int, default=0, help='use smooth loss weight, 0 for not use') + parser.add_argument('--train_dataset_names', type=str, default='train_list.txt', help='chooses training datasets.') + parser.add_argument('--validate_dataset_names', type=str, default='val_list.txt', help='chooses validation datasets.') + parser.add_argument('--n_epochs', type=int, default=500, help='number of epochs') + parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. 
[linear | step | plateau | cosine]') + parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate for adam') + parser.add_argument('--gamma', type=float, default=0.2, help='step learning rate gamma') + parser.add_argument('--lr_decay_iters', type=int, default=250, help='multiply by a gamma every lr_decay_iters iterations') + parser.add_argument('--n_epochs_decay', type=int, default=250, help='number of epochs to linearly decay learning rate to zero') + parser.add_argument('--validate_epoch', type=int, default=50, help='validate model every some epochs, 0 for not validate during training') + parser.add_argument('--loss_smooth_weight', type=float, default=0, help='smooth loss weight, 0 for not use smooth loss') + parser.add_argument('--optimizer', type=str, default='AdamW', help='Adam, AdamW, RMSprop') + + + # data augmentations + parser.add_argument('--gaussian_noise', type=int, default=1, help='whether add gaussian noise to input & groundtruth features') + parser.add_argument('--gaussian_noise_scale', type=float, default=0.01, help='gaussian noise scale') + + + self.isTrain = True + return parser \ No newline at end of file diff --git a/talkingface/utils/live_speech_portraits/options/train_feature2face_options.py b/talkingface/utils/live_speech_portraits/options/train_feature2face_options.py new file mode 100644 index 00000000..25516c8d --- /dev/null +++ b/talkingface/utils/live_speech_portraits/options/train_feature2face_options.py @@ -0,0 +1,63 @@ +from .base_options_feature2face import BaseOptions + + +class TrainOptions(BaseOptions): + def initialize(self): + BaseOptions.initialize(self) + ## dataset settings + self.parser.add_argument('--dataset_names', type=str, default='name', help='chooses how datasets are loaded.') + self.parser.add_argument('--train_dataset_names', type=str, default='train_list.txt') + self.parser.add_argument('--validate_dataset_names', type=str, default='val_list.txt') + + + ## training flags + 
self.parser.add_argument('--display_freq', type=int, default=10, help='frequency of showing training results on screen (iterations)') + self.parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console (epochs)') + self.parser.add_argument('--save_latest_freq', type=int, default=100, help='frequency of saving the latest results (iterations)') + self.parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') + self.parser.add_argument('--continue_train', default=True, action='store_true', help='continue training: load the latest model') + self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') + self.parser.add_argument('--load_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + self.parser.add_argument('--n_epochs_warm_up', type=int, default=5, help='number of epochs warm up') + self.parser.add_argument('--n_epochs', type=int, default=10, help='number of epochs') + self.parser.add_argument('--n_epochs_decay', type=int, default=10, help='number of epochs to linearly decay learning rate to zero') + self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') + self.parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') + self.parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]') + self.parser.add_argument('--lr_decay_iters', type=int, default=900, help='multiply by a gamma every lr_decay_iters iterations') + self.parser.add_argument('--lr_decay_gamma', type=float, default=0.25, help='gamma multiplier applied every lr_decay_iters iterations') + self.parser.add_argument('--TTUR', action='store_true', help='Use TTUR training scheme') + self.parser.add_argument('--gan_mode', type=str, default='ls', help='(ls|original|hinge)') + self.parser.add_argument('--pool_size', type=int, default=1, help='the size of image buffer that stores previously generated images') + self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') + self.parser.add_argument('--frame_jump', type=int, default=1, help='jump frame for training, 1 for not jump') + self.parser.add_argument('--epoch_count', type=int, default=0, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') + self.parser.add_argument('--seq_max_len', type=int, default=120, help='maximum sequence clip frames sent to network per iteration') + + + # for discriminators + self.parser.add_argument('--no_discriminator', type=int, default=0, help='not use discriminator') + self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') + self.parser.add_argument('--num_D', type=int, default=2, help='number of patch scales in each discriminator') + self.parser.add_argument('--n_layers_D', type=int, default=3, help='number of layers in discriminator') + self.parser.add_argument('--no_vgg', action='store_true', help='do not use VGG feature matching loss') + self.parser.add_argument('--no_ganFeat', action='store_true', help='do not match discriminator features') + self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching') + self.parser.add_argument('--sparse_D', action='store_true', help='use sparse 
temporal discriminators to save memory') + + + # for temporal + self.parser.add_argument('--lambda_T', type=float, default=10.0, help='weight for temporal loss') + self.parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') + self.parser.add_argument('--lambda_F', type=float, default=10.0, help='weight for flow loss') + self.parser.add_argument('--lambda_mask', type=float, default=500.0, help='weight for mask l1 loss') + self.parser.add_argument('--n_frames_D', type=int, default=3, help='number of frames to feed into temporal discriminator') + self.parser.add_argument('--n_scales_temporal', type=int, default=2, help='number of temporal scales in the temporal discriminator') + self.parser.add_argument('--n_frames_per_gpu', type=int, default=1, help='the number of frames to load into one GPU at a time. only 1 is supported now') + self.parser.add_argument('--max_frames_backpropagate', type=int, default=1, help='max number of frames to backpropagate') + self.parser.add_argument('--max_t_step', type=int, default=1, help='max spacing between neighboring sampled frames. 
If greater than 1, the network may randomly skip frames during training.') + self.parser.add_argument('--n_frames_total', type=int, default=12, help='the overall number of frames in a sequence to train with') + self.parser.add_argument('--nepochs_step', type=int, default=5, help='how many epochs do we change training sequence length again') + self.parser.add_argument('--nepochs_fix_global', type=int, default=0, help='if specified, only train the finest spatial layer for the given iterations') + + self.isTrain = True diff --git a/talkingface/utils/live_speech_portraits/util/flow_viz.py b/talkingface/utils/live_speech_portraits/util/flow_viz.py new file mode 100644 index 00000000..dcee65e8 --- /dev/null +++ b/talkingface/utils/live_speech_portraits/util/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the Matlab source code of Deqing Sun. 
+ + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/talkingface/utils/live_speech_portraits/util/get_data.py b/talkingface/utils/live_speech_portraits/util/get_data.py new file mode 100644 index 00000000..97edc3ce --- /dev/null +++ b/talkingface/utils/live_speech_portraits/util/get_data.py @@ -0,0 +1,110 @@ +from __future__ import print_function +import os +import tarfile +import requests +from warnings import warn +from zipfile import ZipFile +from bs4 import BeautifulSoup +from os.path import abspath, isdir, join, basename + + +class GetData(object): + """A Python script for downloading CycleGAN or pix2pix datasets. + + Parameters: + technique (str) -- One of: 'cyclegan' or 'pix2pix'. + verbose (bool) -- If True, print additional information. + + Examples: + >>> from util.get_data import GetData + >>> gd = GetData(technique='cyclegan') + >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. + + Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' + and 'scripts/download_cyclegan_model.sh'. 
+ """ + + def __init__(self, technique='cyclegan', verbose=True): + url_dict = { + 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', + 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' + } + self.url = url_dict.get(technique.lower()) + self._verbose = verbose + + def _print(self, text): + if self._verbose: + print(text) + + @staticmethod + def _get_options(r): + soup = BeautifulSoup(r.text, 'lxml') + options = [h.text for h in soup.find_all('a', href=True) + if h.text.endswith(('.zip', 'tar.gz'))] + return options + + def _present_options(self): + r = requests.get(self.url) + options = self._get_options(r) + print('Options:\n') + for i, o in enumerate(options): + print("{0}: {1}".format(i, o)) + choice = input("\nPlease enter the number of the " + "dataset above you wish to download:") + return options[int(choice)] + + def _download_data(self, dataset_url, save_path): + if not isdir(save_path): + os.makedirs(save_path) + + base = basename(dataset_url) + temp_save_path = join(save_path, base) + + with open(temp_save_path, "wb") as f: + r = requests.get(dataset_url) + f.write(r.content) + + if base.endswith('.tar.gz'): + obj = tarfile.open(temp_save_path) + elif base.endswith('.zip'): + obj = ZipFile(temp_save_path, 'r') + else: + raise ValueError("Unknown File Type: {0}.".format(base)) + + self._print("Unpacking Data...") + obj.extractall(save_path) + obj.close() + os.remove(temp_save_path) + + def get(self, save_path, dataset=None): + """ + + Download a dataset. + + Parameters: + save_path (str) -- A directory to save the data to. + dataset (str) -- (optional). A specific dataset to download. + Note: this must include the file extension. + If None, options will be presented for you + to choose from. + + Returns: + save_path_full (str) -- the absolute path to the downloaded data. 
+ + """ + if dataset is None: + selected_dataset = self._present_options() + else: + selected_dataset = dataset + + save_path_full = join(save_path, selected_dataset.split('.')[0]) + + if isdir(save_path_full): + warn("\n'{0}' already exists. Voiding Download.".format( + save_path_full)) + else: + self._print('Downloading Data...') + url = "{0}/{1}".format(self.url, selected_dataset) + self._download_data(url, save_path=save_path) + + return abspath(save_path_full) diff --git a/talkingface/utils/live_speech_portraits/util/html.py b/talkingface/utils/live_speech_portraits/util/html.py new file mode 100644 index 00000000..10f2fbdc --- /dev/null +++ b/talkingface/utils/live_speech_portraits/util/html.py @@ -0,0 +1,67 @@ +import dominate +from dominate.tags import * +import os + + +class HTML: + def __init__(self, web_dir, title, reflesh=0): + self.title = title + self.web_dir = web_dir + self.img_dir = os.path.join(self.web_dir, 'images') + if not os.path.exists(self.web_dir): + os.makedirs(self.web_dir) + if not os.path.exists(self.img_dir): + os.makedirs(self.img_dir) + # print(self.img_dir) + + self.doc = dominate.document(title=title) + if reflesh > 0: + with self.doc.head: + meta(http_equiv="reflesh", content=str(reflesh)) + + def get_image_dir(self): + return self.img_dir + + def add_header(self, str): + with self.doc: + h3(str) + + def add_table(self, border=1): + self.t = table(border=border, style="table-layout: fixed;") + self.doc.add(self.t) + + def add_images(self, ims, txts, links, width=400, height=0): + self.add_table() + with self.t: + with tr(): + for im, txt, link in zip(ims, txts, links): + with td(style="word-wrap: break-word;", halign="center", valign="top"): + with p(): + with a(href=os.path.join('images', link)): + if height != 0: + img(style="width:%dpx;height:%dpx" % (width, height), src=os.path.join('images', im)) + else: + img(style="width:%dpx" % (width), src=os.path.join('images', im)) + br() + p(txt) + + def save(self): + html_file = 
'%s/index.html' % self.web_dir + f = open(html_file, 'wt') + f.write(self.doc.render()) + f.close() + + +if __name__ == '__main__': + html = HTML('web/', 'test_html') + html.add_header('hello world') + + ims = [] + txts = [] + links = [] + for n in range(4): + ims.append('image_%d.jpg' % n) + txts.append('text_%d' % n) + links.append('image_%d.jpg' % n) + html.add_images(ims, txts, links) + html.save() diff --git a/talkingface/utils/live_speech_portraits/util/image_pool.py b/talkingface/utils/live_speech_portraits/util/image_pool.py new file mode 100644 index 00000000..152ef5be --- /dev/null +++ b/talkingface/utils/live_speech_portraits/util/image_pool.py @@ -0,0 +1,32 @@ +import random +import numpy as np +import torch +from torch.autograd import Variable +class ImagePool(): + def __init__(self, pool_size): + self.pool_size = pool_size + if self.pool_size > 0: + self.num_imgs = 0 + self.images = [] + + def query(self, images): + if self.pool_size == 0: + return images + return_images = [] + for image in images.data: + image = torch.unsqueeze(image, 0) + if self.num_imgs < self.pool_size: + self.num_imgs = self.num_imgs + 1 + self.images.append(image) + return_images.append(image) + else: + p = random.uniform(0, 1) + if p > 0.5: + random_id = random.randint(0, self.pool_size-1) + tmp = self.images[random_id].clone() + self.images[random_id] = image + return_images.append(tmp) + else: + return_images.append(image) + return_images = Variable(torch.cat(return_images, 0)) + return return_images diff --git a/talkingface/utils/live_speech_portraits/util/util.py b/talkingface/utils/live_speech_portraits/util/util.py new file mode 100644 index 00000000..6bd1dabb --- /dev/null +++ b/talkingface/utils/live_speech_portraits/util/util.py @@ -0,0 +1,93 @@ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import inspect, re +import numpy as np +import os +import collections +from PIL import Image +import cv2 +from collections import 
OrderedDict + +from . import flow_viz + + + +# Converts a Tensor into a Numpy array +# |imtype|: the desired type of the converted numpy array +def tensor2im(image_tensor, imtype=np.uint8, normalize=True): + if isinstance(image_tensor, list): + image_numpy = [] + for i in range(len(image_tensor)): + image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) + return image_numpy + + if isinstance(image_tensor, torch.autograd.Variable): + image_tensor = image_tensor.data + if len(image_tensor.size()) == 5: + image_tensor = image_tensor[0, -1] + if len(image_tensor.size()) == 4: + image_tensor = image_tensor[0] + image_tensor = image_tensor[:3] + image_numpy = image_tensor.cpu().float().numpy() + if normalize: + image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 + else: + image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 + #image_numpy = (np.transpose(image_numpy, (1, 2, 0)) * std + mean) * 255.0 + image_numpy = np.clip(image_numpy, 0, 255) + if image_numpy.shape[2] == 1: + image_numpy = image_numpy[:,:,0] + return image_numpy.astype(imtype) + + +def tensor2flow(flo, imtype=np.uint8): + flo = flo[0].permute(1,2,0).cpu().detach().numpy() + flo = flow_viz.flow_to_image(flo) + return flo + + +def add_dummy_to_tensor(tensors, add_size=0): + if add_size == 0 or tensors is None: return tensors + if isinstance(tensors, list): + return [add_dummy_to_tensor(tensor, add_size) for tensor in tensors] + + if isinstance(tensors, torch.Tensor): + dummy = torch.zeros_like(tensors)[:add_size] + tensors = torch.cat([dummy, tensors]) + return tensors + +def remove_dummy_from_tensor(tensors, remove_size=0): + if remove_size == 0 or tensors is None: return tensors + if isinstance(tensors, list): + return [remove_dummy_from_tensor(tensor, remove_size) for tensor in tensors] + + if isinstance(tensors, torch.Tensor): + tensors = tensors[remove_size:] + return tensors + +def save_image(image_numpy, image_path): + image_pil = Image.fromarray(image_numpy) + 
image_pil.save(image_path) + +def print_numpy(x, val=True, shp=False): + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + +def mkdirs(paths): + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + +def mkdir(path): + if not os.path.exists(path): + os.makedirs(path) + diff --git a/talkingface/utils/live_speech_portraits/util/visualizer.py b/talkingface/utils/live_speech_portraits/util/visualizer.py new file mode 100644 index 00000000..fd25a6eb --- /dev/null +++ b/talkingface/utils/live_speech_portraits/util/visualizer.py @@ -0,0 +1,149 @@ +### Copyright (C) 2017 NVIDIA Corporation. All rights reserved. +### Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). +import numpy as np +import os +import time +from . import util +from . import html +import scipy.misc +try: + from StringIO import StringIO # Python 2.7 +except ImportError: + from io import BytesIO # Python 3.x + +class Visualizer(): + def __init__(self, opt): + self.opt = opt + self.tf_log = opt.tf_log + self.use_html = opt.isTrain and not opt.no_html + self.win_size = opt.display_winsize + self.name = opt.name + if opt.isTrain: + if self.tf_log: + from torch.utils.tensorboard import SummaryWriter + # import tensorflow as tf + # self.tf = tf + self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') + # self.writer = tf.summary.FileWriter(self.log_dir) + self.writer = SummaryWriter(self.log_dir, flush_secs=1) + + if self.use_html: + self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') + self.img_dir = os.path.join(self.web_dir, 'images') + print('create web directory %s...' 
% self.web_dir) + util.mkdirs([self.web_dir, self.img_dir]) + + self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') + with open(self.log_name, "a") as log_file: + now = time.strftime("%c") + log_file.write('================ Training Loss (%s) ================\n' % now) + + # |visuals|: dictionary of images to display or save + def display_current_results(self, visuals, epoch, step): +# if self.tf_log: # show images in tensorboard output +# img_summaries = [] +# for label, image_numpy in visuals.items(): +# # Write the image to a string +# try: +# s = StringIO() +# except: +# s = BytesIO() +# scipy.misc.toimage(image_numpy).save(s, format="jpeg") +# # Create an Image object +# img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) +# # Create a Summary value +# img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) +# +# # Create and write Summary +# summary = self.tf.Summary(value=img_summaries) +# self.writer.add_summary(summary, step) + + if self.use_html: # save images to a html file + for label, image_numpy in visuals.items(): + if isinstance(image_numpy, list): + for i in range(len(image_numpy)): + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)) + util.save_image(image_numpy[i], img_path) + else: + img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label)) + util.save_image(image_numpy, img_path) + + # update website + webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) + for n in range(epoch, 0, -1): + webpage.add_header('epoch [%d]' % n) + ims = [] + txts = [] + links = [] + + for label, image_numpy in visuals.items(): + if isinstance(image_numpy, list): + for i in range(len(image_numpy)): + img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i) + ims.append(img_path) + txts.append(label+str(i)) + links.append(img_path) + else: + img_path = 'epoch%.3d_%s.jpg' % (n, label) + 
ims.append(img_path) + txts.append(label) + links.append(img_path) + if len(ims) < 5: + webpage.add_images(ims, txts, links, width=self.win_size) + else: + num = int(round(len(ims)/2.0)) + webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) + webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) + webpage.save() + + # errors: dictionary of error labels and values + def plot_current_errors(self, errors, step): + if self.tf_log: + for tag, value in errors.items(): +# summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) +# self.writer.add_summary(summary, step) + self.writer.add_scalar(tag, value, step) + + + # errors: same format as |errors| of plotCurrentErrors + def print_current_errors(self, epoch, i, errors, t): + message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) + for k, v in sorted(errors.items()): + if v != 0: + message += '%s: %.3f ' % (k, v) + + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) + + # save image to the disk + def save_images(self, image_dir, visuals, image_path, webpage=None): + dirname = os.path.basename(os.path.dirname(image_path[0])) + image_dir = os.path.join(image_dir, dirname) + util.mkdir(image_dir) + name = image_path +# name = os.path.basename(image_path[0]) +# name = os.path.splitext(name)[0] + + if webpage is not None: + webpage.add_header(name) + ims, txts, links = [], [], [] + + for label, image_numpy in visuals.items(): + save_ext = 'jpg' + image_name = '%s_%s.%s' % (label, name, save_ext) + save_path = os.path.join(image_dir, image_name) + util.save_image(image_numpy, save_path) + + if webpage is not None: + ims.append(image_name) + txts.append(label) + links.append(image_name) + if webpage is not None: + webpage.add_images(ims, txts, links, width=self.win_size) + + def vis_print(self, message): + print(message) + with open(self.log_name, "a") as log_file: + log_file.write('%s\n' % message) + 
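The `ImagePool` class in `image_pool.py` above implements a history-buffer trick for GAN training: while the pool is filling, every generated image passes straight through; once the pool is full, each incoming image is, with probability 0.5, swapped for a previously stored one, so the discriminator also sees older generator outputs. A minimal pure-Python sketch of that behaviour (torch-free; `HistoryBuffer` is an illustrative name, not the repo's API):

```python
import random


class HistoryBuffer:
    """Illustrative sketch of the ImagePool idea from image_pool.py:
    keep up to pool_size past samples and, once full, return an old
    sample half the time, storing the new one in its place."""

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.samples = []

    def query(self, sample):
        if self.pool_size == 0:              # pooling disabled: pass-through
            return sample
        if len(self.samples) < self.pool_size:
            self.samples.append(sample)      # still filling the buffer
            return sample
        if random.random() > 0.5:            # 50%: hand back an old sample
            idx = random.randrange(self.pool_size)
            old = self.samples[idx]
            self.samples[idx] = sample       # replace it with the new one
            return old
        return sample                        # 50%: pass through unchanged
```

With `pool_size=0` the buffer degenerates to a pass-through, matching the early return at the top of the repo's `query`.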
diff --git a/talkingface/utils/live_speech_portraits/utils.py b/talkingface/utils/live_speech_portraits/utils.py new file mode 100644 index 00000000..8eb2dfaf --- /dev/null +++ b/talkingface/utils/live_speech_portraits/utils.py @@ -0,0 +1,366 @@ +import sys + +sys.path.append("..") +from . import audio_funcs + +import numpy as np +from math import cos, sin +import torch +from numpy.linalg import solve +from scipy.ndimage import gaussian_filter1d +from sklearn.neighbors import KDTree +import time +from tqdm import tqdm + + +class camera(object): + def __init__(self, fx=0, fy=0, cx=0, cy=0): + self.name = 'default camera' + self.fx = fx + self.fy = fy + self.cx = cx + self.cy = cy + self.relative_rotation = np.diag([1, 1, 1]).astype(np.float32) + self.relative_translation = np.zeros(3, dtype=np.float32) + + # self.intrinsic = np.array([[self.fx, 0, self.cx], + # [0, self.fy, self.cy], + # [0, 0, 1]]) + + def intrinsic(self, trans_matrix=0): + ''' compute the intrinsic matrix + ''' + intrinsic = np.array([[self.fx, 0, self.cx], + [0, self.fy, self.cy], + [0, 0, 1]]) + + return intrinsic + + def relative(self): + ''' compute the relative transformation 4x4 matrix with respect to the + first camera kinect. specially the kinect's relative transformation + matrix is exact a identity matrix. + ''' + relative = np.eye(4, dtype=np.float32) + relative[:3, :3] = self.relative_rotation + relative[:3, 3] = self.relative_translation + + return relative + + def transform_intrinsic(self, transform_matrix): + ''' change the camera intrinsic matrix + transformed_intrinsic = transform_matrix * intrinsic + ''' + scale = transform_matrix[0, 0] + self.fx *= scale + self.fy *= scale + self.cx = scale * self.cx + transform_matrix[0, 2] + self.cy = scale * self.cy + transform_matrix[1, 2] + + +def compute_mel_one_sequence(audio, hop_length=int(16000 / 120), winlen=1 / 60, winstep=0.5 / 60, sr=16000, fps=60, + device='cpu'): + ''' compute mel for an audio sequence. 
+ ''' + device = torch.device(device) + Audio2Mel_torch = audio_funcs.Audio2Mel(n_fft=512, hop_length=int(16000 / 120), win_length=int(16000 / 60), + sampling_rate=16000, + n_mel_channels=80, mel_fmin=90, mel_fmax=7600.0).to(device) + + nframe = int(audio.shape[0] / 16000 * 60) + mel_nframe = 2 * nframe + mel_frame_len = int(sr * winlen) + mel_frame_step = sr * winstep + + mel80s = np.zeros([mel_nframe, 80]) + for i in range(mel_nframe): + # for i in tqdm(range(mel_nframe)): + st = int(i * mel_frame_step) + audio_clip = audio[st: st + mel_frame_len] + if len(audio_clip) < mel_frame_len: + audio_clip = np.concatenate([audio_clip, np.zeros([mel_frame_len - len(audio_clip)])]) + audio_clip_device = torch.from_numpy(audio_clip).unsqueeze(0).unsqueeze(0).to(device).float() + mel80s[i] = Audio2Mel_torch(audio_clip_device).cpu().numpy()[0].T # [1, 80] + + return mel80s + + +def KNN(feats, feat_database, K=10): + ''' compute KNN for feat in feat base + ''' + tree = KDTree(feat_database, leaf_size=100000) + print('start computing KNN ...') + st = time.time() + dist, ind = tree.query(feats, k=K) + et = time.time() + print('Taken time: ', et - st) + + return dist, ind + + +def KNN_with_torch(feats, feat_database, K=10): + feats = torch.from_numpy(feats) # .cuda() + feat_database = torch.from_numpy(feat_database) # .cuda() + # Training + feat_base_norm = (feat_database ** 2).sum(-1) + # print('start computing KNN ...') + # st = time.time() + feats_norm = (feats ** 2).sum(-1) + diss = (feats_norm.view(-1, 1) + + feat_base_norm.view(1, -1) + - 2 * feats @ feat_database.t() # Rely on cuBLAS for better performance! 
+ ) + ind = diss.topk(K, dim=1, largest=False).indices + # et = time.time() + # print('Taken time: ', et-st) + + return ind.cpu().numpy() + + +def solve_LLE_projection(feat, feat_base): + '''find LLE projection weights given feat base and target feat + Args: + feat: [ndim, ] target feat + feat_base: [K, ndim] K-nearest feat base + ======================================= + We need to solve the following function + ``` + min|| feat - \sum_0^k{w_i} * feat_base_i ||, s.t. \sum_0^k{w_i}=1 + ``` + equals to: + ft = w1*f1 + w2*f2 + ... + wk*fk, s.t. w1+w2+...+wk=1 + = (1-w2-...-wk)*f1 + w2*f2 + ... + wk*fk + ft-f1 = w2*(f2-f1) + w3*(f3-f1) + ... + wk*(fk-f1) + ft-f1 = (f2-f1, f3-f1, ..., fk-f1) dot (w2, w3, ..., wk).T + B = A dot w_, here, B: [ndim,] A: [ndim, k-1], w_: [k-1,] + Finally, + ft' = (1-w2-..wk, w2, ..., wk) dot (f1, f2, ..., fk) + ======================================= + Returns: + w: [K,] linear weights, sums to 1 + ft': [ndim,] reconstructed feats + ''' + K, ndim = feat_base.shape + if K == 1: + feat_fuse = feat_base[0] + w = np.array([1]) + else: + w = np.zeros(K) + B = feat - feat_base[0] # [ndim,] + A = (feat_base[1:] - feat_base[0]).T # [ndim, K-1] + AT = A.T + w[1:] = solve(AT.dot(A), AT.dot(B)) + w[0] = 1 - w[1:].sum() + feat_fuse = w.dot(feat_base) + + return w, feat_fuse + + +def compute_LLE_projection_frame(feats, feat_database, ind): + nframe = feats.shape[0] + feat_fuse = np.zeros_like(feats) + w = np.zeros([nframe, ind.shape[1]]) + current_K_feats = feat_database[ind] + w, feat_fuse = solve_LLE_projection(feats, current_K_feats) + + return w, feat_fuse + + +def compute_LLE_projection_all_frame(feats, feat_database, ind, nframe): + nframe = feats.shape[0] + feat_fuse = np.zeros_like(feats) + w = np.zeros([nframe, ind.shape[1]]) + for i in tqdm(range(nframe), desc='LLE projection'): + current_K_feats = feat_database[ind[i]] + w[i], feat_fuse[i] = solve_LLE_projection(feats[i], current_K_feats) + + return w, feat_fuse + + +def angle2matrix(angles, 
gradient='false'): + ''' get rotation matrix from three rotation angles(degree). right-handed. + Args: + angles: [3,]. x, y, z angles + x: pitch. positive for looking down. + y: yaw. positive for looking left. + z: roll. positive for tilting head right. + gradient(str): whether to compute gradient matrix: dR/d_x,y,z + Returns: + R: [3, 3]. rotation matrix. + ''' + x, y, z = np.deg2rad(angles[0]), np.deg2rad(angles[1]), np.deg2rad(angles[2]) + # x + Rx = np.array([[1, 0, 0], + [0, cos(x), -sin(x)], + [0, sin(x), cos(x)]]) + # y + Ry = np.array([[cos(y), 0, sin(y)], + [0, 1, 0], + [-sin(y), 0, cos(y)]]) + # z + Rz = np.array([[cos(z), -sin(z), 0], + [sin(z), cos(z), 0], + [0, 0, 1]]) + + R = Rz.dot(Ry.dot(Rx)) + # R=Rx.dot(Ry.dot(Rz)) + + if gradient != 'true': + return R.astype(np.float32) + elif gradient == 'true': + # gradident matrix + dRxdx = np.array([[0, 0, 0], + [0, -sin(x), -cos(x)], + [0, cos(x), -sin(x)]]) + dRdx = Rz.dot(Ry.dot(dRxdx)) * np.pi / 180 + dRydy = np.array([[-sin(y), 0, cos(y)], + [0, 0, 0], + [-cos(y), 0, -sin(y)]]) + dRdy = Rz.dot(dRydy.dot(Rx)) * np.pi / 180 + dRzdz = np.array([[-sin(z), -cos(z), 0], + [cos(z), -sin(z), 0], + [0, 0, 0]]) + dRdz = dRzdz.dot(Ry.dot(Rx)) * np.pi / 180 + + return R.astype(np.float32), [dRdx.astype(np.float32), dRdy.astype(np.float32), dRdz.astype(np.float32)] + + +def project_landmarks(camera_intrinsic, viewpoint_R, viewpoint_T, scale, headposes, pts_3d): + ''' project 2d landmarks given predicted 3d landmarks & headposes and user-defined + camera & viewpoint parameters + ''' + rot, trans = angle2matrix(headposes[:3]), headposes[3:][:, None] + pts3d_headpose = scale * rot.dot(pts_3d.T) + trans + pts3d_viewpoint = viewpoint_R.dot(pts3d_headpose) + viewpoint_T[:, None] + pts2d_project = camera_intrinsic.dot(pts3d_viewpoint) + pts2d_project[:2, :] /= pts2d_project[2, :] # divide z + pts2d_project = pts2d_project[:2, :].T + + return pts2d_project, rot, trans + + +def landmark_smooth_3d(pts3d, smooth_sigma=0, 
area='only_mouth'): + ''' smooth the input 3d landmarks using gaussian filters on each dimension. + Args: + pts3d: [N, 73, 3] + ''' + # per-landmark smooth + if not smooth_sigma == 0: + if area == 'all': + pts3d = gaussian_filter1d(pts3d.reshape(-1, 73 * 3), smooth_sigma, axis=0).reshape(-1, 73, 3) + elif area == 'only_mouth': + mouth_pts3d = pts3d[:, 46:64, :].copy() + mouth_pts3d = gaussian_filter1d(mouth_pts3d.reshape(-1, 18 * 3), smooth_sigma, axis=0).reshape(-1, 18, 3) + pts3d = gaussian_filter1d(pts3d.reshape(-1, 73 * 3), smooth_sigma, axis=0).reshape(-1, 73, 3) + pts3d[:, 46:64, :] = mouth_pts3d + + return pts3d + + +mouth_indices = list(range(46 * 2, 64 * 2)) +upper_outer_lip = list(range(47, 52)) +upper_inner_lip = [63, 62, 61] +lower_inner_lip = [58, 59, 60] +lower_outer_lip = list(range(57, 52, -1)) +lower_mouth = [53, 54, 55, 56, 57, 58, 59, 60] +upper_mouth = [46, 47, 48, 49, 50, 51, 52, 61, 62, 63] + + +def mouth_pts_AMP(pts3d, is_delta=True, method='XY', paras=[1, 1]): + ''' mouth region AMP to control the reaction amplitude. 
+ method: 'XY', 'delta', 'XYZ', 'LowerMore' or 'CloseSmall' + ''' + if method == 'XY': + AMP_scale_x, AMP_scale_y = paras + if is_delta: + pts3d[:, 46:64, 0] *= AMP_scale_x + pts3d[:, 46:64, 1] *= AMP_scale_y + else: + mean_mouth3d_xy = pts3d[:, 46:64, :2].mean(axis=0) + pts3d[:, 46:64, 0] += (AMP_scale_x - 1) * (pts3d[:, 46:64, 0] - mean_mouth3d_xy[:, 0]) + pts3d[:, 46:64, 1] += (AMP_scale_y - 1) * (pts3d[:, 46:64, 1] - mean_mouth3d_xy[:, 1]) + elif method == 'delta': + AMP_scale_x, AMP_scale_y = paras + if is_delta: + diff = AMP_scale_x * (pts3d[1:, 46:64] - pts3d[:-1, 46:64]) + pts3d[1:, 46:64] += diff + + elif method == 'XYZ': + AMP_scale_x, AMP_scale_y, AMP_scale_z = paras + if is_delta: + pts3d[:, 46:64, 0] *= AMP_scale_x + pts3d[:, 46:64, 1] *= AMP_scale_y + pts3d[:, 46:64, 2] *= AMP_scale_z + + elif method == 'LowerMore': + upper_x, upper_y, upper_z, lower_x, lower_y, lower_z = paras + if is_delta: + pts3d[:, upper_mouth, 0] *= upper_x + pts3d[:, upper_mouth, 1] *= upper_y + pts3d[:, upper_mouth, 2] *= upper_z + pts3d[:, lower_mouth, 0] *= lower_x + pts3d[:, lower_mouth, 1] *= lower_y + pts3d[:, lower_mouth, 2] *= lower_z + + elif method == 'CloseSmall': + open_x, open_y, open_z, close_x, close_y, close_z = paras + nframe = pts3d.shape[0] + for i in tqdm(range(nframe), desc='AMP mouth..'): + if sum(pts3d[i, upper_mouth, 1] > 0) + sum(pts3d[i, lower_mouth, 1] < 0) > 16 * 0.3: + # open + pts3d[i, 46:64, 0] *= open_x + pts3d[i, 46:64, 1] *= open_y + pts3d[i, 46:64, 2] *= open_z + else: + # close + pts3d[:, 46:64, 0] *= close_x + pts3d[:, 46:64, 1] *= close_y + pts3d[:, 46:64, 2] *= close_z + + return pts3d + + +def solve_intersect_mouth(pts3d): + ''' solve the generated intersec lips, usually happens in mouth AMP usage. 
+ Args: + pts3d: [N, 73, 3] + ''' + upper_inner = pts3d[:, upper_inner_lip] + lower_inner = pts3d[:, lower_inner_lip] + + lower_inner_y = lower_inner[:, :, 1] + upper_inner_y = upper_inner[:, :, 1] + # all three inner lip flip + flip = lower_inner_y > upper_inner_y + flip = np.where(flip.sum(axis=1) == 3)[0] + + # flip frames + inner_y_diff = lower_inner_y[flip] - upper_inner_y[flip] + half_inner_y_diff = inner_y_diff * 0.5 + # upper inner + pts3d[flip[:, None], upper_inner_lip, 1] += half_inner_y_diff + # lower inner + pts3d[flip[:, None], lower_inner_lip, 1] -= half_inner_y_diff + # upper outer + pts3d[flip[:, None], upper_outer_lip, 1] += half_inner_y_diff.mean() + # lower outer + pts3d[flip[:, None], lower_outer_lip, 1] -= half_inner_y_diff.mean() + + return pts3d + + +def headpose_smooth(headpose, smooth_sigmas=[0, 0], method='gaussian'): + rot_sigma, trans_sigma = smooth_sigmas + rot = gaussian_filter1d(headpose.reshape(-1, 6)[:, :3], rot_sigma, axis=0).reshape(-1, 3) + trans = gaussian_filter1d(headpose.reshape(-1, 6)[:, 3:], trans_sigma, axis=0).reshape(-1, 3) + headpose_smooth = np.concatenate([rot, trans], axis=1) + + return headpose_smooth + + + + + + + + diff --git a/talkingface/utils/utils.py b/talkingface/utils/utils.py index a5019491..ceadcf96 100644 --- a/talkingface/utils/utils.py +++ b/talkingface/utils/utils.py @@ -50,7 +50,6 @@ def get_model(model_name): "voice_conversion" ] - model_file_name = model_name.lower() model_module = None for submodule in model_submodule: @@ -445,11 +444,3 @@ def create_dataset(config): dataset_class = getattr(dataset_module, model_name+'Dataset') return dataset_class(config, config['train_filelist']), dataset_class(config, config['val_filelist']) - - - - - - - - diff --git "a/\345\221\275\344\273\244\350\241\214\346\210\252\345\233\276.png" "b/\345\221\275\344\273\244\350\241\214\346\210\252\345\233\276.png" new file mode 100644 index 00000000..1ca43b35 Binary files /dev/null and 
"b/\345\221\275\344\273\244\350\241\214\346\210\252\345\233\276.png" differ diff --git "a/\346\274\224\347\244\272\350\247\206\351\242\221.zip" "b/\346\274\224\347\244\272\350\247\206\351\242\221.zip" new file mode 100644 index 00000000..4cb4dd23 Binary files /dev/null and "b/\346\274\224\347\244\272\350\247\206\351\242\221.zip" differ diff --git "a/\347\274\272\345\244\261\350\256\255\347\273\203\347\232\204\345\216\237\345\233\240-2.png" "b/\347\274\272\345\244\261\350\256\255\347\273\203\347\232\204\345\216\237\345\233\240-2.png" new file mode 100644 index 00000000..dcbbfa89 Binary files /dev/null and "b/\347\274\272\345\244\261\350\256\255\347\273\203\347\232\204\345\216\237\345\233\240-2.png" differ diff --git "a/\347\274\272\345\244\261\350\256\255\347\273\203\347\232\204\345\216\237\345\233\240.png" "b/\347\274\272\345\244\261\350\256\255\347\273\203\347\232\204\345\216\237\345\233\240.png" new file mode 100644 index 00000000..efa6ab93 Binary files /dev/null and "b/\347\274\272\345\244\261\350\256\255\347\273\203\347\232\204\345\216\237\345\233\240.png" differ diff --git "a/\350\257\255\351\237\263\350\257\206\345\210\253-\350\256\272\346\226\207\345\244\215\347\216\260-LiveSpeechPortraits.docx" "b/\350\257\255\351\237\263\350\257\206\345\210\253-\350\256\272\346\226\207\345\244\215\347\216\260-LiveSpeechPortraits.docx" new file mode 100644 index 00000000..014b4a73 Binary files /dev/null and "b/\350\257\255\351\237\263\350\257\206\345\210\253-\350\256\272\346\226\207\345\244\215\347\216\260-LiveSpeechPortraits.docx" differ diff --git "a/\351\252\214\350\257\201\346\210\252\345\233\276.png" "b/\351\252\214\350\257\201\346\210\252\345\233\276.png" new file mode 100644 index 00000000..636c5e1e Binary files /dev/null and "b/\351\252\214\350\257\201\346\210\252\345\233\276.png" differ