diff --git a/README.md b/README.md
index 7cd82328..4738901a 100644
--- a/README.md
+++ b/README.md
@@ -1,210 +1,44 @@
-# talkingface-toolkit
-## Framework overview
-### checkpoints
-Stores the extra pre-trained models needed to train and evaluate models; the [README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/checkpoints/README.md) in that folder gives more detail
+# talkingface-toolkit group assignment
+Group members: 高艺芙, 贺芳琪, 陈清扬 (names are ordered by the logical sequence of the work, not by workload)
+# Model selection
+We reproduce the [StarGAN-VC](https://github.com/kamepong/StarGAN-VC) model, which is designed specifically for voice conversion. It extends the StarGAN framework, originally developed for image-to-image translation, and focuses on converting the vocal characteristics of one speaker into those of another.
+# Assignment environment
+python3.8

-### dataset
-Holds the datasets and the data produced by preprocessing them; see the [README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/dataset/README.md) inside dataset for details
+Development environment: PyCharm 2022.1.3

-### saved
-Stores the model checkpoints saved during training; created automatically when a model is saved
+Framework: PyTorch 2.1

-### talkingface
-The main functional module, containing all the core code
+Operating systems: Windows 11, macOS

-#### config
-Automatically generates all model, dataset, training and evaluation configuration from the model and dataset names
-```
-config/
+Audio processing library: Librosa 0.10.1

-├── configurator.py
+Data processing library: NumPy 1.20.3
+# Dataset
+Contains speakers p225, p226 and p227, each split into test, train and val; see the dataset folder for details.
+# Commands
+Main command: python run_talkingface.py --model stargan --dataset vctk

-```
-#### data
-- dataprocess: model-specific data-processing code (e.g., audio feature extraction or inference-time preprocessing implemented by the upstream repository). Create a corresponding file here if your model needs it
-- dataset: every model must subclass `torch.utils.data.Dataset` to load its data, in a file named `model_name+'_dataset.py'`; the return value of `__getitem__()` should be a dict. (core part)
-```
-data/
+Data preparation command: ./recipes/run_train.sh

-├── dataprocess
+Loading a split: for example, to use the training split during training, call StarganDataset(config, config['train_filelist']); a usage sketch is given below.
+# Implemented functionality
+Overall: running python run_talkingface.py --model stargan --dataset vctk performs data processing and model training directly.

-| ├── wav2lip_process.py
+Data preparation: we extracted the StarGAN model's configuration-file parameters and applied them in the toolkit.

-| ├── xxxx_process.py
-
-├── dataset
-
-| ├── wav2lip_dataset.py
-
-| ├── xxx_dataset.py
-```
-
-#### evaluate
-Code for model evaluation
-The LSE metric takes a list of generated videos
-The SSIM metric takes lists of generated and ground-truth videos
-
-#### model
-The networks of the implemented models and their methods (core part)
-
-Three main categories:
-- audio-driven
-- image-driven
-- nerf-based (neural radiance field methods)
-
-```
-model/
-
-├── audio_driven_talkingface
-
-| ├── wav2lip.py
-
-├── image_driven_talkingface
-
-| ├── xxxx.py
-
-├── nerf_based_talkingface
-
-| ├── xxxx.py
-
-├── abstract_talkingface.py
-
-```
-
-#### properties
-Default configuration files, including:
-- dataset configs
-- model configs
-- the global config
-
-Add configuration files for your model and dataset as needed; the global config `overall.yaml` normally stays unchanged
-```
-properties/
-
-├── dataset
-
-| ├── xxx.yaml
-
-├── model
-
-| ├── xxx.yaml
-
-├── overall.yaml
-
-```
-
-#### quick_start
-The generic entry point: it configures the dataset and model from the command-line arguments, then runs training and evaluation (normally no changes needed)
-```
-quick_start/
-
-├── quick_start.py
-
-```
-
-#### trainer
-The main class for training and evaluation. If the base `Trainer` covers everything, no new class is needed; if the model's training has model-specific parts, subclass `Trainer`, mainly overriding `_train_epoch()` and `_valid_epoch()`. Name the subclass `{model_name}Trainer`
-```
-trainer/
-
-├── trainer.py
-
-```
-
-#### utils
-Shared utilities, including `s3fd` face detection and methods for extracting frames and audio from video, plus helpers that look up the model and data classes from the configuration.
-Normally left unchanged, though genuinely necessary, broadly useful data-processing helpers may be added.
-
-## Usage
-### Requirements
-- `python=3.8`
-- `torch==1.13.1+cu116` (GPU build; a CPU build works if your device does not support CUDA)
-- `numpy==1.20.3`
-- `librosa==0.10.1`
-
-Keep the versions of these packages as close to the above as possible
-
-Two ways to set up the rest of the environment:
-```
-pip install -r requirements.txt
-
-or
-
-conda env create -f environment.yml
-```
-
-Using a conda virtual environment is strongly recommended!!!
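For illustration, a minimal sketch of the dataset call described above. The import path and the minimal config dict are assumptions about the toolkit layout, not confirmed toolkit API:

```python
# Hypothetical usage sketch: load the train split the way the README describes.
# StarganDataset and the 'train_filelist' key come from the text above; the
# import path and the filelist location below are assumed for illustration.
from talkingface.data.dataset.stargan_dataset import StarganDataset

config = {'train_filelist': './dataset/train.txt'}  # minimal illustrative config
train_data = StarganDataset(config, config['train_filelist'])
print(len(train_data))  # number of training utterances in the split
```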
-
-### Training and evaluation
-
-```bash
-python run_talkingface.py --model=xxxx --dataset=xxxx (--other_parameters=xxxxxx)
-```
-
-### Weight files
-
-- Weights required for LSE evaluation: syncnet_v2.model [Baidu Netdisk download](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc)
-- Lip-expert weights required by wav2lip: lipsync_expert.pth [Baidu Netdisk download](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc)
-
-## Candidate papers:
-### Audio_driven talkingface
-| Model | Paper | Code |
-|:--------:|:--------:|:--------:|
-| MakeItTalk | [paper](https://arxiv.org/abs/2004.12992) | [code](https://github.com/yzhou359/MakeItTalk) |
-| MEAD | [paper](https://wywu.github.io/projects/MEAD/support/MEAD.pdf) | [code](https://github.com/uniBruce/Mead) |
-| RhythmicHead | [paper](https://arxiv.org/pdf/2007.08547v1.pdf) | [code](https://github.com/lelechen63/Talking-head-Generation-with-Rhythmic-Head-Motion) |
-| PC-AVS | [paper](https://arxiv.org/abs/2104.11116) | [code](https://github.com/Hangz-nju-cuhk/Talking-Face_PC-AVS) |
-| EVP | [paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Ji_Audio-Driven_Emotional_Video_Portraits_CVPR_2021_paper.pdf) | [code](https://github.com/jixinya/EVP) |
-| LSP | [paper](https://arxiv.org/abs/2109.10595) | [code](https://github.com/YuanxunLu/LiveSpeechPortraits) |
-| EAMM | [paper](https://arxiv.org/pdf/2205.15278.pdf) | [code](https://github.com/jixinya/EAMM/) |
-| DiffTalk | [paper](https://arxiv.org/abs/2301.03786) | [code](https://github.com/sstzal/DiffTalk) |
-| TalkLip | [paper](https://arxiv.org/pdf/2303.17480.pdf) | [code](https://github.com/Sxjdwang/TalkLip) |
-| EmoGen | [paper](https://arxiv.org/pdf/2303.11548.pdf) | [code](https://github.com/sahilg06/EmoGen) |
-| SadTalker | [paper](https://arxiv.org/abs/2211.12194) | [code](https://github.com/OpenTalker/SadTalker) |
-| HyperLips | [paper](https://arxiv.org/abs/2310.05720) | [code](https://github.com/semchan/HyperLips) |
-| PHADTF | [paper](http://arxiv.org/abs/2002.10137) | [code](https://github.com/yiranran/Audio-driven-TalkingFace-HeadPose) |
-| VideoReTalking | [paper](https://arxiv.org/abs/2211.14758) | [code](https://github.com/OpenTalker/video-retalking#videoretalking--audio-based-lip-synchronization-for-talking-head-video-editing-in-the-wild-) |
-| |
-
-
-
-### Image_driven talkingface
-| Model | Paper | Code |
-|:--------:|:--------:|:--------:|
-| PIRenderer | [paper](https://arxiv.org/pdf/2109.08379.pdf) | [code](https://github.com/RenYurui/PIRender) |
-| StyleHEAT | [paper](https://arxiv.org/pdf/2203.04036.pdf) | [code](https://github.com/OpenTalker/StyleHEAT) |
-| MetaPortrait | [paper](https://arxiv.org/abs/2212.08062) | [code](https://github.com/Meta-Portrait/MetaPortrait) |
-| |
-### Nerf-based talkingface
-| Model | Paper | Code |
-|:--------:|:--------:|:--------:|
-| AD-NeRF | [paper](https://arxiv.org/abs/2103.11078) | [code](https://github.com/YudongGuo/AD-NeRF) |
-| GeneFace | [paper](https://arxiv.org/abs/2301.13430) | [code](https://github.com/yerfor/GeneFace) |
-| DFRF | [paper](https://arxiv.org/abs/2207.11770) | [code](https://github.com/sstzal/DFRF) |
-| |
-### text_to_speech
-| Model | Paper | Code |
-|:--------:|:--------:|:--------:|
-| VITS | [paper](https://arxiv.org/abs/2106.06103) | [code](https://github.com/jaywalnut310/vits) |
-| Glow TTS | [paper](https://arxiv.org/abs/2005.11129) | [code](https://github.com/jaywalnut310/glow-tts) |
-| FastSpeech2 | [paper](https://arxiv.org/abs/2006.04558v1) | [code](https://github.com/ming024/FastSpeech2) |
-| StyleTTS2 | [paper](https://arxiv.org/abs/2306.07691) | [code](https://github.com/yl4579/StyleTTS2) |
-| Grad-TTS | [paper](https://arxiv.org/abs/2105.06337) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS) |
-| FastSpeech | [paper](https://arxiv.org/abs/1905.09263) | [code](https://github.com/xcmyz/FastSpeech) |
-| |
-### voice_conversion
-| Model | Paper | Code |
-|:--------:|:--------:|:--------:|
-| StarGAN-VC | [paper](http://www.kecl.ntt.co.jp/people/kameoka.hirokazu/Demos/stargan-vc2/index.html) | [code](https://github.com/kamepong/StarGAN-VC) |
-| Emo-StarGAN | [paper](https://www.researchgate.net/publication/373161292_Emo-StarGAN_A_Semi-Supervised_Any-to-Many_Non-Parallel_Emotion-Preserving_Voice_Conversion) | [code](https://github.com/suhitaghosh10/emo-stargan) |
-| adaptive-VC | [paper](https://arxiv.org/abs/1904.05742) | [code](https://github.com/jjery2243542/adaptive_voice_conversion) |
-| DiffVC | [paper](https://arxiv.org/abs/2109.13821) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC) |
-| Assem-VC | [paper](https://arxiv.org/abs/2104.00931) | [code](https://github.com/maum-ai/assem-vc) |
-| |
-
-## Assignment requirements
-- Make sure training and validation can be launched from the command line with only the model and dataset names. (If the upstream repository provides no training code, training may be skipped.)
-- Every group must submit a README describing the implemented functionality, final training and validation screenshots, the dependencies used, and the division of work among members.
+Data part: following the stargan source code, we preprocess the data by extracting mel-spectrograms and converting them into a usable form; we split the data into three subsets at an 80%/10%/10% ratio, generating test.txt, train.txt and val.txt, which are stored in the dataset folder.
+# Result screenshots
+![](./md_img/1.jpg)
+![](./md_img/2.jpg)
+![](./md_img/3.jpg)
+![](./md_img/4.jpg)
+# Division of work (ordered by task sequence)
+高艺芙-1120213132-07022102: set up the experiment environment, installed and updated the packages with pip install -r requirements.txt, resolved the resulting errors, and ran training to obtain results. Prepared the experiment data, ran the StarGAN model to obtain the configuration files Arctic.json and StarGAN.json, converted them into Arctic.yaml and StarGAN.yaml, integrated them into the corresponding file structure under talkingface-toolkit/talkingface/properties, and did the final debugging.
+贺芳琪-1120210640-08012101: adjusted parts of the configuration; mainly responsible for the data side, i.e. data preprocessing and the dataset split. Integrated the preprocessing code from stargan-vc's dataset.py, compute_statistics.py, extract_features.py and normalize_features.py into the dataset and dataprocess folders under talkingface-toolkit/talkingface/data, and updated the preprocessing-related parameters in the toolkit's yaml files. A detailed explanation can be found in the readme of the data part.
+
+陈清扬-1120213599-07112106: refactored the model code, the training code and the inference file, adjusted the configuration files, tested the code and fixed bugs, and wrote the experiment report
+
diff --git a/convert.py b/convert.py
new file mode 100644
index 00000000..3fefa46a
--- /dev/null
+++ b/convert.py
@@ -0,0 +1,208 @@
+# Copyright 2021 Hirokazu Kameoka
+
+import os
+import argparse
+import torch
+import json
+import numpy as np
+import re
+import pickle
+from tqdm import tqdm
+import yaml
+
+import librosa
+import soundfile as sf
+from sklearn.preprocessing import StandardScaler
+
+import net
+from extract_features import logmelfilterbank
+
+import sys
+sys.path.append(os.path.abspath("pwg"))
+from pwg.parallel_wavegan.utils import load_model
+from pwg.parallel_wavegan.utils import read_hdf5
+
+def audio_transform(wav_filepath, scaler, kwargs, device):
+
+    trim_silence = kwargs['trim_silence']
+    top_db = kwargs['top_db']
+    flen = kwargs['flen']
+    fshift = kwargs['fshift']
+    fmin = kwargs['fmin']
+    fmax = kwargs['fmax']
+    num_mels = kwargs['num_mels']
+    fs = kwargs['fs']
+
+    audio, fs_ = sf.read(wav_filepath)
+    if trim_silence:
+        #print('trimming.')
+        audio, _ = librosa.effects.trim(audio, top_db=top_db, frame_length=2048, hop_length=512)
+    if fs != fs_:
+        #print('resampling.')
+        # Keyword arguments are required by librosa >= 0.10 (the version in the README);
+        # this matches the call already used in train_stargan_model.py.
+        audio = librosa.resample(audio, orig_sr=fs_, target_sr=fs)
+    melspec_raw = logmelfilterbank(audio, fs, fft_size=flen, hop_size=fshift,
+                                   fmin=fmin, fmax=fmax, num_mels=num_mels)
+    melspec_raw = melspec_raw.astype(np.float32) # n_frame x n_mels
+
+    melspec_norm = scaler.transform(melspec_raw)
+    melspec_norm = melspec_norm.T # n_mels x n_frame
+
+    return torch.tensor(melspec_norm[None]).to(device, dtype=torch.float)
+
+def extract_num(s, p,
ret=0): + search = p.search(s) + if search: + return int(search.groups()[0]) + else: + return ret + +def listdir_ext(dirpath,ext): + p = re.compile(r'(\d+)') + out = [] + for file in sorted(os.listdir(dirpath), key=lambda s: extract_num(s, p)): + if os.path.splitext(file)[1]==ext: + out.append(file) + return out + +def find_newest_model_file(model_dir, tag): + mfile_list = os.listdir(model_dir) + checkpoint = max([int(os.path.splitext(os.path.splitext(mfile)[0])[0]) for mfile in mfile_list if mfile.endswith('.{}.pt'.format(tag))]) + return checkpoint + + +def synthesis(melspec, model_nv, nv_config, savepath, device): + ## Parallel WaveGAN / MelGAN + melspec = torch.tensor(melspec, dtype=torch.float).to(device) + #start = time.time() + x = model_nv.inference(melspec).view(-1) + #elapsed_time = time.time() - start + #rtf2 = elapsed_time/audio_len + #print ("elapsed_time (waveform generation): {0}".format(elapsed_time) + "[sec]") + #print ("real time factor (waveform generation): {0}".format(rtf2)) + + # save as PCM 16 bit wav file + if not os.path.exists(os.path.dirname(savepath)): + os.makedirs(os.path.dirname(savepath)) + sf.write(savepath, x.detach().cpu().clone().numpy(), nv_config["sampling_rate"], "PCM_16") + +def main(): + parser = argparse.ArgumentParser(description='Testing StarGAN-VC') + parser.add_argument('--gpu', '-g', type=int, default=-1, help='GPU ID (negative value indicates CPU)') + parser.add_argument('-i', '--input', type=str, default='/misc/raid58/kameoka.hirokazu/python/db/arctic/wav/test', + help='root data folder that contains the wav files of input speech') + parser.add_argument('-o', '--out', type=str, default='./out/arctic', + help='root data folder where the wav files of the converted speech will be saved.') + parser.add_argument('--dataconf', type=str, default='./dump/arctic/data_config.json') + parser.add_argument('--stat', type=str, default='./dump/arctic/stat.pkl', help='state file used for normalization') + parser.add_argument('--model_rootdir', '-mdir', type=str, default='./model/arctic/', help='model file directory') + parser.add_argument('--checkpoint', '-ckpt', type=int, default=0, help='model checkpoint to load (0 indicates the newest model)') + parser.add_argument('--experiment_name', '-exp', default='experiment1', type=str, help='experiment name') + parser.add_argument('--vocoder', '-voc', default='hifigan.v1', type=str, + help='neural vocoder type name (e.g., hifigan.v1, hifigan.v2, parallel_wavegan.v1)') + parser.add_argument('--voc_dir', '-vdir', type=str, default='pwg/egs/arctic_4spk_flen64ms_fshift8ms/voc1', + help='directory of trained neural vocoder') + args = parser.parse_args() + + # Set up GPU + if torch.cuda.is_available() and args.gpu >= 0: + device = torch.device('cuda:%d' % args.gpu) + else: + device = torch.device('cpu') + if device.type == 'cuda': + torch.cuda.set_device(device) + + input_dir = args.input + data_config_path = args.dataconf + model_config_path = os.path.join(args.model_rootdir,args.experiment_name,'model_config.json') + with open(data_config_path) as f: + data_config = json.load(f) + with open(model_config_path) as f: + model_config = json.load(f) + checkpoint = args.checkpoint + + num_mels = model_config['num_mels'] + arch_type = model_config['arch_type'] + loss_type = model_config['loss_type'] + n_spk = model_config['n_spk'] + trg_spk_list = model_config['spk_list'] + zdim = model_config['zdim'] + hdim = model_config['hdim'] + mdim = model_config['mdim'] + sdim = model_config['sdim'] + normtype = 
model_config['normtype'] + src_conditioning = model_config['src_conditioning'] + + stat_filepath = args.stat + melspec_scaler = StandardScaler() + if os.path.exists(stat_filepath): + with open(stat_filepath, mode='rb') as f: + melspec_scaler = pickle.load(f) + print('Loaded mel-spectrogram statistics successfully.') + else: + print('Stat file not found.') + + # Set up main model + gen = net.Generator1(num_mels, n_spk, zdim, hdim, sdim, normtype, src_conditioning) if arch_type=='conv' else net.Generator1(num_mels, n_spk, zdim, hdim, sdim, normtype, src_conditioning) + dis = net.Discriminator1(num_mels, n_spk, mdim, normtype) if arch_type=='conv' else net.Discriminator1(num_mels, n_spk, mdim, normtype) + models = { + 'gen': gen, + 'dis': dis + } + models['stargan'] = net.StarGAN(models['gen'],models['dis'],n_spk,loss_type) + + for tag in ['gen', 'dis']: + model_dir = os.path.join(args.model_rootdir,args.experiment_name) + vc_checkpoint_idx = find_newest_model_file(model_dir, tag) if checkpoint <= 0 else checkpoint + mfilename = '{}.{}.pt'.format(vc_checkpoint_idx,tag) + path = os.path.join(args.model_rootdir,args.experiment_name,mfilename) + if path is not None: + model_checkpoint = torch.load(path, map_location=device) + models[tag].load_state_dict(model_checkpoint['model_state_dict']) + print('{}: {}'.format(tag, os.path.abspath(path))) + + for tag in ['gen', 'dis']: + #models[tag].to(device).eval() + models[tag].to(device).train(mode=True) + + # Set up nv + vocoder = args.vocoder + voc_dir = args.voc_dir + voc_yaml_path = os.path.join(voc_dir,'conf', '{}.yaml'.format(vocoder)) + checkpointlist = listdir_ext( + os.path.join(voc_dir,'exp','train_nodev_all_{}'.format(vocoder)),'.pkl') + nv_checkpoint = os.path.join(voc_dir,'exp', + 'train_nodev_all_{}'.format(vocoder), + checkpointlist[-1]) # Find and use the newest checkpoint model. 
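+    # Descriptive note: the block below loads the vocoder's YAML config, merges
+    # the CLI args into it, restores the trained Parallel WaveGAN / HiFi-GAN
+    # weights, and switches the vocoder to eval mode for waveform synthesis.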
+ print('vocoder: {}'.format(os.path.abspath(nv_checkpoint))) + + with open(voc_yaml_path) as f: + nv_config = yaml.load(f, Loader=yaml.Loader) + nv_config.update(vars(args)) + model_nv = load_model(nv_checkpoint, nv_config) + model_nv.remove_weight_norm() + model_nv = model_nv.eval().to(device) + + src_spk_list = sorted(os.listdir(input_dir)) + + for i, src_spk in enumerate(src_spk_list): + src_wav_dir = os.path.join(input_dir, src_spk) + for j, trg_spk in enumerate(trg_spk_list): + if src_spk != trg_spk: + print('Converting {}2{}...'.format(src_spk, trg_spk)) + for n, src_wav_filename in enumerate(os.listdir(src_wav_dir)): + src_wav_filepath = os.path.join(src_wav_dir, src_wav_filename) + src_melspec = audio_transform(src_wav_filepath, melspec_scaler, data_config, device) + k_t = j + k_s = i if src_conditioning else None + + conv_melspec = models['stargan'](src_melspec, k_t, k_s) + + conv_melspec = conv_melspec[0,:,:].detach().cpu().clone().numpy() + conv_melspec = conv_melspec.T # n_frames x n_mels + + out_wavpath = os.path.join(args.out,args.experiment_name,'{}'.format(vc_checkpoint_idx),vocoder,'{}2{}'.format(src_spk,trg_spk), src_wav_filename) + synthesis(conv_melspec, model_nv, nv_config, out_wavpath, device) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/dataset.py b/dataset.py new file mode 100644 index 00000000..471c89d5 --- /dev/null +++ b/dataset.py @@ -0,0 +1,65 @@ +# Copyright 2021 Hirokazu Kameoka + +import os +import numpy as np +import torch +from torch.utils.data import Dataset +import h5py +import math +import random + +def walk_files(root, extension): + for path, dirs, files in os.walk(root): + for file in files: + if file.endswith(extension): + yield os.path.join(path, file) + +class MultiDomain_Dataset(Dataset): + def __init__(self, *feat_dirs): + self.n_domain = len(feat_dirs) + self.filenames_all = [[os.path.join(d,t) for t in sorted(os.listdir(d))] for d in feat_dirs] + #self.filenames_all = [[t for t in walk_files(d, '.h5')] for d in feat_dirs] + self.feat_dirs = feat_dirs + + def __len__(self): + return min(len(f) for f in self.filenames_all) + + def __getitem__(self, idx): + melspec_list = [] + for d in range(self.n_domain): + with h5py.File(self.filenames_all[d][idx], "r") as f: + melspec = f["melspec"][()] # n_freq x n_time + melspec_list.append(melspec) + return melspec_list + +def collate_fn(batch): + #batch[b][s]: melspec (n_freq x n_frame) + #b: batch size + #s: speaker ID + + batchsize = len(batch) + n_spk = len(batch[0]) + melspec_list = [[batch[b][s] for b in range(batchsize)] for s in range(n_spk)] + #melspec_list[s][b]: melspec (n_freq x n_frame) + #s: speaker ID + #b: batch size + + n_freq = melspec_list[0][0].shape[0] + + X_list = [] + for s in range(n_spk): + maxlen=0 + for b in range(batchsize): + if maxlen= 0: + device = torch.device('cuda:%d' % args.gpu) + else: + device = torch.device('cpu') + if device.type == 'cuda': + torch.cuda.set_device(device) + + input_dir = args.input + data_config_path = args.dataconf + model_config_path = os.path.join(args.model_rootdir,args.experiment_name,'model_config.json') + with open(data_config_path) as f: + data_config = json.load(f) + with open(model_config_path) as f: + model_config = json.load(f) + checkpoint = args.checkpoint + + num_mels = model_config['num_mels'] + arch_type = model_config['arch_type'] + loss_type = model_config['loss_type'] + n_spk = model_config['n_spk'] + trg_spk_list = model_config['spk_list'] + zdim = model_config['zdim'] + hdim = 
model_config['hdim'] + mdim = model_config['mdim'] + sdim = model_config['sdim'] + normtype = model_config['normtype'] + src_conditioning = model_config['src_conditioning'] + + stat_filepath = args.stat + melspec_scaler = StandardScaler() + if os.path.exists(stat_filepath): + with open(stat_filepath, mode='rb') as f: + melspec_scaler = pickle.load(f) + print('Loaded mel-spectrogram statistics successfully.') + else: + print('Stat file not found.') + + # Set up main model + gen = net.Generator1(num_mels, n_spk, zdim, hdim, sdim, normtype, src_conditioning) if arch_type=='conv' else net.Generator1(num_mels, n_spk, zdim, hdim, sdim, normtype, src_conditioning) + dis = net.Discriminator1(num_mels, n_spk, mdim, normtype) if arch_type=='conv' else net.Discriminator1(num_mels, n_spk, mdim, normtype) + models = { + 'gen': gen, + 'dis': dis + } + models['stargan'] = net.StarGAN(models['gen'],models['dis'],n_spk,loss_type) + + for tag in ['gen', 'dis']: + model_dir = os.path.join(args.model_rootdir,args.experiment_name) + vc_checkpoint_idx = find_newest_model_file(model_dir, tag) if checkpoint <= 0 else checkpoint + mfilename = '{}.{}.pt'.format(vc_checkpoint_idx,tag) + path = os.path.join(args.model_rootdir,args.experiment_name,mfilename) + if path is not None: + model_checkpoint = torch.load(path, map_location=device) + models[tag].load_state_dict(model_checkpoint['model_state_dict']) + print('{}: {}'.format(tag, os.path.abspath(path))) + + for tag in ['gen', 'dis']: + #models[tag].to(device).eval() + models[tag].to(device).train(mode=True) + + # Set up nv + vocoder = args.vocoder + voc_dir = args.voc_dir + voc_yaml_path = os.path.join(voc_dir,'conf', '{}.yaml'.format(vocoder)) + checkpointlist = listdir_ext( + os.path.join(voc_dir,'exp','train_nodev_all_{}'.format(vocoder)),'.pkl') + nv_checkpoint = os.path.join(voc_dir,'exp', + 'train_nodev_all_{}'.format(vocoder), + checkpointlist[-1]) # Find and use the newest checkpoint model. 
+ print('vocoder: {}'.format(os.path.abspath(nv_checkpoint))) + + with open(voc_yaml_path) as f: + nv_config = yaml.load(f, Loader=yaml.Loader) + nv_config.update(vars(args)) + model_nv = load_model(nv_checkpoint, nv_config) + model_nv.remove_weight_norm() + model_nv = model_nv.eval().to(device) + + src_spk_list = sorted(os.listdir(input_dir)) + + for i, src_spk in enumerate(src_spk_list): + src_wav_dir = os.path.join(input_dir, src_spk) + for j, trg_spk in enumerate(trg_spk_list): + if src_spk != trg_spk: + print('Converting {}2{}...'.format(src_spk, trg_spk)) + for n, src_wav_filename in enumerate(os.listdir(src_wav_dir)): + src_wav_filepath = os.path.join(src_wav_dir, src_wav_filename) + src_melspec = audio_transform(src_wav_filepath, melspec_scaler, data_config, device) + k_t = j + k_s = i if src_conditioning else None + + conv_melspec = models['stargan'](src_melspec, k_t, k_s) + + conv_melspec = conv_melspec[0,:,:].detach().cpu().clone().numpy() + conv_melspec = conv_melspec.T # n_frames x n_mels + + out_wavpath = os.path.join(args.out,args.experiment_name,'{}'.format(vc_checkpoint_idx),vocoder,'{}2{}'.format(src_spk,trg_spk), src_wav_filename) + synthesis(conv_melspec, model_nv, nv_config, out_wavpath, device) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/models/stargan/dataset.py b/models/stargan/dataset.py new file mode 100644 index 00000000..471c89d5 --- /dev/null +++ b/models/stargan/dataset.py @@ -0,0 +1,65 @@ +# Copyright 2021 Hirokazu Kameoka + +import os +import numpy as np +import torch +from torch.utils.data import Dataset +import h5py +import math +import random + +def walk_files(root, extension): + for path, dirs, files in os.walk(root): + for file in files: + if file.endswith(extension): + yield os.path.join(path, file) + +class MultiDomain_Dataset(Dataset): + def __init__(self, *feat_dirs): + self.n_domain = len(feat_dirs) + self.filenames_all = [[os.path.join(d,t) for t in sorted(os.listdir(d))] for d in feat_dirs] + #self.filenames_all = [[t for t in walk_files(d, '.h5')] for d in feat_dirs] + self.feat_dirs = feat_dirs + + def __len__(self): + return min(len(f) for f in self.filenames_all) + + def __getitem__(self, idx): + melspec_list = [] + for d in range(self.n_domain): + with h5py.File(self.filenames_all[d][idx], "r") as f: + melspec = f["melspec"][()] # n_freq x n_time + melspec_list.append(melspec) + return melspec_list + +def collate_fn(batch): + #batch[b][s]: melspec (n_freq x n_frame) + #b: batch size + #s: speaker ID + + batchsize = len(batch) + n_spk = len(batch[0]) + melspec_list = [[batch[b][s] for b in range(batchsize)] for s in range(n_spk)] + #melspec_list[s][b]: melspec (n_freq x n_frame) + #s: speaker ID + #b: batch size + + n_freq = melspec_list[0][0].shape[0] + + X_list = [] + for s in range(n_spk): + maxlen=0 + for b in range(batchsize): + if maxlen n_frame_: + x = nn.ReplicationPad1d((0, n_frame-n_frame_))(x) + return self.gen(x, k_t, k_s)[:,:,0:n_frame_] + + def calc_advloss_g(self, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts): + df_adv_ss = df_adv_ss.permute(0,2,1).reshape(-1,1) + df_adv_st = df_adv_st.permute(0,2,1).reshape(-1,1) + df_adv_tt = df_adv_tt.permute(0,2,1).reshape(-1,1) + df_adv_ts = df_adv_ts.permute(0,2,1).reshape(-1,1) + + if self.loss_type=='wgan': + # Wasserstein GAN with gradient penalty (WGAN-GP) + AdvLoss_g = ( + torch.sum(-df_adv_ss) + + torch.sum(-df_adv_st) + + torch.sum(-df_adv_tt) + + torch.sum(-df_adv_ts) + ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + 
df_adv_ts.numel()) + + elif self.loss_type=='lsgan': + # Least squares GAN (LSGAN) + AdvLoss_g = 0.5 * ( + torch.sum((df_adv_ss - torch.ones_like(df_adv_ss))**2) + + torch.sum((df_adv_st - torch.ones_like(df_adv_st))**2) + + torch.sum((df_adv_tt - torch.ones_like(df_adv_tt))**2) + + torch.sum((df_adv_ts - torch.ones_like(df_adv_ts))**2) + ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel()) + + elif self.loss_type=='cgan': + # Regular GAN with the sigmoid cross-entropy criterion (CGAN) + AdvLoss_g = ( + F.binary_cross_entropy_with_logits(df_adv_ss, torch.ones_like(df_adv_ss), reduction='sum') + + F.binary_cross_entropy_with_logits(df_adv_st, torch.ones_like(df_adv_st), reduction='sum') + + F.binary_cross_entropy_with_logits(df_adv_tt, torch.ones_like(df_adv_tt), reduction='sum') + + F.binary_cross_entropy_with_logits(df_adv_ts, torch.ones_like(df_adv_ts), reduction='sum') + ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel()) + + return AdvLoss_g + + def calc_clsloss_g(self, df_cls_ss, df_cls_st, df_cls_tt, df_cls_ts, k_s, k_t): + device = df_cls_ss.device + + df_cls_ss = df_cls_ss.permute(0,2,1).reshape(-1,self.n_spk) + df_cls_st = df_cls_st.permute(0,2,1).reshape(-1,self.n_spk) + df_cls_tt = df_cls_tt.permute(0,2,1).reshape(-1,self.n_spk) + df_cls_ts = df_cls_ts.permute(0,2,1).reshape(-1,self.n_spk) + + cf_ss = k_s*torch.ones(len(df_cls_ss), device=device, dtype=torch.long) + cf_st = k_t*torch.ones(len(df_cls_st), device=device, dtype=torch.long) + cf_tt = k_t*torch.ones(len(df_cls_tt), device=device, dtype=torch.long) + cf_ts = k_s*torch.ones(len(df_cls_ts), device=device, dtype=torch.long) + + ClsLoss_g = ( + F.cross_entropy(df_cls_ss, cf_ss, reduction='sum') + + F.cross_entropy(df_cls_st, cf_st, reduction='sum') + + F.cross_entropy(df_cls_tt, cf_tt, reduction='sum') + + F.cross_entropy(df_cls_ts, cf_ts, reduction='sum') + ) / (df_cls_ss.numel() + df_cls_st.numel() + df_cls_tt.numel() + df_cls_ts.numel()) + + return ClsLoss_g + + def calc_advloss_d(self, x_s, x_t, xf_ts, xf_st, dr_adv_s, dr_adv_t, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts): + device = x_s.device + B_s = len(x_s) + B_t = len(x_t) + + dr_adv_s = dr_adv_s.permute(0,2,1).reshape(-1,1) + dr_adv_t = dr_adv_t.permute(0,2,1).reshape(-1,1) + df_adv_ss = df_adv_ss.permute(0,2,1).reshape(-1,1) + df_adv_st = df_adv_st.permute(0,2,1).reshape(-1,1) + df_adv_tt = df_adv_tt.permute(0,2,1).reshape(-1,1) + df_adv_ts = df_adv_ts.permute(0,2,1).reshape(-1,1) + + if self.loss_type=='wgan': + # Wasserstein GAN with gradient penalty (WGAN-GP) + AdvLoss_d_r = ( + torch.sum(-dr_adv_s) + + torch.sum(-dr_adv_t) + ) / (dr_adv_s.numel() + dr_adv_t.numel()) + AdvLoss_d_f = ( + torch.sum(df_adv_ss) + + torch.sum(df_adv_st) + + torch.sum(df_adv_tt) + + torch.sum(df_adv_ts) + ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel()) + AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f + + elif self.loss_type=='lsgan': + # Least squares GAN (LSGAN) + + AdvLoss_d_r = 0.5 * ( + torch.sum((dr_adv_s - torch.ones_like(dr_adv_s))**2) + + torch.sum((dr_adv_t - torch.ones_like(dr_adv_t))**2) + ) / (dr_adv_s.numel() + dr_adv_t.numel()) + AdvLoss_d_f = 0.5 * ( + torch.sum(df_adv_ss**2) + + torch.sum(df_adv_st**2) + + torch.sum(df_adv_tt**2) + + torch.sum(df_adv_ts**2) + ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel()) + AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f + + elif self.loss_type=='cgan': + # Regular GAN with sigmoid cross-entropy criterion (CGAN) + 
AdvLoss_d_r = ( + F.binary_cross_entropy_with_logits(dr_adv_s, torch.ones_like(dr_adv_s), reduction='sum') + + F.binary_cross_entropy_with_logits(dr_adv_t, torch.ones_like(dr_adv_t), reduction='sum') + ) / (dr_adv_s.numel() + dr_adv_t.numel()) + AdvLoss_d_f = ( + F.binary_cross_entropy_with_logits(df_adv_ss, torch.zeros_like(df_adv_ss), reduction='sum') + + F.binary_cross_entropy_with_logits(df_adv_st, torch.zeros_like(df_adv_st), reduction='sum') + + F.binary_cross_entropy_with_logits(df_adv_tt, torch.zeros_like(df_adv_tt), reduction='sum') + + F.binary_cross_entropy_with_logits(df_adv_ts, torch.zeros_like(df_adv_ts), reduction='sum') + ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel()) + AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f + + # Gradient penalty loss + alpha_t = torch.rand(B_t, 1, 1, requires_grad=True).to(device) + interpolates = alpha_t * x_t + ((1 - alpha_t) * xf_ts) + interpolates = interpolates.to(device) + disc_interpolates, _ = self.dis(interpolates) + disc_interpolates = torch.sum(disc_interpolates) + gradients = torch.autograd.grad(outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + gradnorm = torch.sqrt(torch.sum(gradients * gradients, (1, 2))) + loss_gp_t = ((gradnorm - 1)**2).mean() + + alpha_s = torch.rand(B_s, 1, 1, requires_grad=True).to(device) + interpolates = alpha_s * x_s + ((1 - alpha_s) * xf_st) + interpolates = interpolates.to(device) + interpolates = torch.autograd.Variable(interpolates, requires_grad=True) + disc_interpolates, _ = self.dis(interpolates) + disc_interpolates = torch.sum(disc_interpolates) + gradients = torch.autograd.grad(outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True)[0] + gradnorm = torch.sqrt(torch.sum(gradients * gradients, (1, 2))) + loss_gp_s = ((gradnorm - 1)**2).mean() + + GradLoss_d = loss_gp_s + loss_gp_t + + return AdvLoss_d, GradLoss_d + + def calc_clsloss_d(self, dr_cls_s, dr_cls_t, k_s, k_t): + device = dr_cls_s.device + + dr_cls_s = dr_cls_s.permute(0,2,1).reshape(-1,self.n_spk) + dr_cls_t = dr_cls_t.permute(0,2,1).reshape(-1,self.n_spk) + + cr_s = k_s*torch.ones(len(dr_cls_s), device=device, dtype=torch.long) + cr_t = k_t*torch.ones(len(dr_cls_t), device=device, dtype=torch.long) + + ClsLoss_d = ( + F.cross_entropy(dr_cls_s, cr_s, reduction='sum') + + F.cross_entropy(dr_cls_t, cr_t, reduction='sum') + ) / (dr_cls_s.numel() + dr_cls_t.numel()) + + return ClsLoss_d + + def calc_gen_loss(self, x_s, x_t, k_s, k_t): + # Generator outputs + xf_ss = self.gen(x_s, k_s, k_s) + xf_ts = self.gen(x_t, k_s, k_t) + xf_tt = self.gen(x_t, k_t, k_t) + xf_st = self.gen(x_s, k_t, k_s) + + # Discriminator outputs + df_adv_ss, df_cls_ss = self.dis(xf_ss) + df_adv_st, df_cls_st = self.dis(xf_st) + df_adv_tt, df_cls_tt = self.dis(xf_tt) + df_adv_ts, df_cls_ts = self.dis(xf_ts) + + # Adversarial loss + AdvLoss_g = self.calc_advloss_g(df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts) + + # Classifier loss + ClsLoss_g = self.calc_clsloss_g(df_cls_ss, df_cls_st, df_cls_tt, df_cls_ts, k_s, k_t) + + # Cycle-consistency loss + CycLoss = ( + torch.sum(torch.abs(x_s - self.gen(xf_st, k_s, k_t))) + + torch.sum(torch.abs(x_t - self.gen(xf_ts, k_t, k_s))) + ) / (x_s.numel() + x_t.numel()) + + # Reconstruction loss + RecLoss = ( + torch.sum(torch.abs(x_s - xf_ss)) + + torch.sum(torch.abs(x_t - 
xf_tt)) + ) / (x_s.numel() + x_t.numel()) + + return AdvLoss_g, ClsLoss_g, CycLoss, RecLoss + + def calc_dis_loss(self, x_s, x_t, k_s, k_t): + device = x_s.device + + # Generator outputs + xf_ss = self.gen(x_s, k_s, k_s) + xf_ts = self.gen(x_t, k_s, k_t) + xf_tt = self.gen(x_t, k_t, k_t) + xf_st = self.gen(x_s, k_t, k_s) + + # Discriminator outputs + dr_adv_s, dr_cls_s = self.dis(x_s) + dr_adv_t, dr_cls_t = self.dis(x_t) + df_adv_ss, _ = self.dis(xf_ss) + df_adv_st, _ = self.dis(xf_st) + df_adv_tt, _ = self.dis(xf_tt) + df_adv_ts, _ = self.dis(xf_ts) + + # Adversarial loss + AdvLoss_d, GradLoss_d = self.calc_advloss_d(x_s, x_t, xf_ts, xf_st, dr_adv_s, dr_adv_t, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts) + + # Classifier loss + ClsLoss_d = self.calc_clsloss_d(dr_cls_s, dr_cls_t, k_s, k_t) + + return AdvLoss_d, GradLoss_d, ClsLoss_d diff --git a/models/stargan/normalize_features.py b/models/stargan/normalize_features.py new file mode 100644 index 00000000..0b4aa157 --- /dev/null +++ b/models/stargan/normalize_features.py @@ -0,0 +1,94 @@ +import argparse +import joblib +import logging +import os + +import h5py +import numpy as np +from tqdm import tqdm +from sklearn.preprocessing import StandardScaler +import pickle + +def walk_files(root, extension): + for path, dirs, files in os.walk(root): + for file in files: + if file.endswith(extension): + yield os.path.join(path, file) + +def melspec_transform(melspec, scaler): + # melspec.shape: (n_freq, n_time) + # scaler.transform assumes the first axis to be the time axis + melspec = scaler.transform(melspec.T) + #import pdb;pdb.set_trace() # Breakpoint + melspec = melspec.T + return melspec + +def normalize_features(src_filepath, dst_filepath, melspec_transform): + try: + with h5py.File(src_filepath, "r") as f: + melspec = f["melspec"][()] + melspec = melspec_transform(melspec) + + if not os.path.exists(os.path.dirname(dst_filepath)): + os.makedirs(os.path.dirname(dst_filepath), exist_ok=True) + with h5py.File(dst_filepath, "w") as f: + f.create_dataset("melspec", data=melspec) + + #logging.info(f"{dst_filepath}...[{melspec.shape}].") + return melspec.shape + + except: + logging.info(f"{dst_filepath}...failed.") + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--src', type=str, + default='./dump/arctic/feat/train', + help='data folder that contains the raw features extracted from VoxCeleb2 Dataset') + parser.add_argument('--dst', type=str, default='./dump/arctic/norm_feat/train', + help='data folder where the normalized features are stored') + parser.add_argument('--stat', type=str, default='./dump/arctic/stat.pkl', + help='state file used for normalization') + parser.add_argument('--ext', type=str, default='.h5') + args = parser.parse_args() + + src = args.src + dst = args.dst + ext = args.ext + stat_filepath = args.stat + + fmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s' + datafmt = '%m/%d/%Y %I:%M:%S' + logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datafmt) + + melspec_scaler = StandardScaler() + if os.path.exists(stat_filepath): + with open(stat_filepath, mode='rb') as f: + melspec_scaler = pickle.load(f) + print('Loaded mel-spectrogram statistics successfully.') + else: + print('Stat file not found.') + + root = src + fargs_list = [ + [ + f, + f.replace(src, dst), + lambda x: melspec_transform(x, melspec_scaler), + ] + for f in walk_files(root, ext) + ] + + #import pdb;pdb.set_trace() # Breakpoint + # debug + #normalize_features(*fargs_list[0]) + # test + #results = 
joblib.Parallel(n_jobs=-1)( + # joblib.delayed(normalize_features)(*f) for f in tqdm(fargs_list) + #) + results = joblib.Parallel(n_jobs=16)( + joblib.delayed(normalize_features)(*f) for f in tqdm(fargs_list) + ) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/models/stargan/train_stargan_model.py b/models/stargan/train_stargan_model.py new file mode 100644 index 00000000..d6e7dc05 --- /dev/null +++ b/models/stargan/train_stargan_model.py @@ -0,0 +1,467 @@ +import itertools +import torch +from torch import optim +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader +import argparse +import joblib +import logging +import os +import warnings +import json + +import librosa +import soundfile as sf +import h5py +import numpy as np +from tqdm import tqdm +from sklearn.preprocessing import StandardScaler +import pickle +from dataset import MultiDomain_Dataset, collate_fn +import net + +def walk_files(root, extension): + for path, dirs, files in os.walk(root): + for file in files: + if file.endswith(extension): + yield os.path.join(path, file) + +def logmelfilterbank(audio, + sampling_rate, + fft_size=1024, + hop_size=256, + win_length=None, + window="hann", + num_mels=80, + fmin=None, + fmax=None, + eps=1e-10, + ): + x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size, + win_length=win_length, window=window, pad_mode="reflect") + spc = np.abs(x_stft).T # (#frames, #bins) + + # get mel basis + fmin = 0 if fmin is None else fmin + fmax = sampling_rate / 2 if fmax is None else fmax + # mel_basis = librosa.filters.mel(sampling_rate, fft_size, num_mels, fmin, fmax) + mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax) + + return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T))) + +def extract_melspec(src_filepath, dst_filepath, kwargs): + # print('what') + # try: + warnings.filterwarnings('ignore') + + trim_silence = kwargs['trim_silence'] + top_db = kwargs['top_db'] + flen = kwargs['flen'] + fshift = kwargs['fshift'] + fmin = kwargs['fmin'] + fmax = kwargs['fmax'] + num_mels = kwargs['num_mels'] + fs = kwargs['fs'] + + audio, fs_ = sf.read(src_filepath) + if trim_silence: + # print('trimming.') + audio, _ = librosa.effects.trim(audio, top_db=top_db, frame_length=2048, hop_length=512) + # print('xzz') + if fs != fs_: + # print('resampling.') + # audio = librosa.resample(audio, fs_, fs) + audio = librosa.resample(y=audio, orig_sr=fs_, target_sr=fs) + + melspec_raw = logmelfilterbank(audio,fs, fft_size=flen,hop_size=fshift, + fmin=fmin, fmax=fmax, num_mels=num_mels) + melspec_raw = melspec_raw.astype(np.float32) + melspec_raw = melspec_raw.T # n_mels x n_frame + if not os.path.exists(os.path.dirname(dst_filepath)): + os.makedirs(os.path.dirname(dst_filepath), exist_ok=True) + with h5py.File(dst_filepath, "w") as f: + f.create_dataset("melspec", data=melspec_raw) + logging.info(f"{dst_filepath}...[{melspec_raw.shape}].") + + +def makedirs_if_not_exists(dir): + if not os.path.exists(dir): + os.makedirs(dir) + + +def comb(N, r): + iterable = list(range(0, N)) + return list(itertools.combinations(iterable, 2)) + + +src = './dataset/vctk/data' +dst = './models/stargan/dump/arctic/feat/train' +ext = '.wav' + +fmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s' +datafmt = '%m/%d/%Y %I:%M:%S' +logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datafmt) + +data_config = { + 'num_mels': 80, + 'fs': 16000, + 'flen': 1024, + 'fshift': 128, + 'fmin': 
80, + 'fmax': 7600, + 'trim_silence': True, + 'top_db': 30 +} +configpath = './models/stargan/dump/arctic/data_config.json' +if not os.path.exists(os.path.dirname(configpath)): + os.makedirs(os.path.dirname(configpath)) +with open(configpath, 'w') as outfile: + json.dump(data_config, outfile, indent=4) + +fargs_list = [ + [ + f, + f.replace(src, dst).replace(ext, ".h5"), + data_config, + ] + for f in walk_files(src, ext) +] +print(fargs_list) +results = joblib.Parallel(n_jobs=1)( + joblib.delayed(extract_melspec)(*f) for f in tqdm(fargs_list) +) + + +def walk_files(root, extension): + for path, dirs, files in os.walk(root): + for file in files: + if file.endswith(extension): + yield os.path.join(path, file) + + +def melspec_transform(melspec, scaler): + melspec = scaler.transform(melspec.T) + melspec = melspec.T + return melspec + +def normalize_features(src_filepath, dst_filepath, melspec_transform): + try: + with h5py.File(src_filepath, "r") as f: + melspec = f["melspec"][()] + melspec = melspec_transform(melspec) + + if not os.path.exists(os.path.dirname(dst_filepath)): + os.makedirs(os.path.dirname(dst_filepath), exist_ok=True) + with h5py.File(dst_filepath, "w") as f: + f.create_dataset("melspec", data=melspec) + + # logging.info(f"{dst_filepath}...[{melspec.shape}].") + return melspec.shape + + except: + logging.info(f"{dst_filepath}...failed.") + +# parser.add_argument('--src', type=str, +# default='./models/stargan/dump/arctic/feat/train', +# help='data folder that contains the raw features extracted from VoxCeleb2 Dataset') +# parser.add_argument('--dst', type=str, default='./models/stargan/dump/arctic/norm_feat/train', +# help='data folder where the normalized features are stored') +# parser.add_argument('--stat', type=str, default='./models/stargan/dump/arctic/stat.pkl', +# help='state file used for normalization') +# parser.add_argument('--ext', type=str, default='.h5') + + +src = './models/stargan/dump/arctic/feat/train' +dst = './models/stargan/dump/arctic/norm_feat/train' +ext = '.h5' +stat_filepath = './models/stargan/dump/arctic/stat.pkl' + +fmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s' +datafmt = '%m/%d/%Y %I:%M:%S' +logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datafmt) + +melspec_scaler = StandardScaler() +if os.path.exists(stat_filepath): + with open(stat_filepath, mode='rb') as f: + melspec_scaler = pickle.load(f) + print('Loaded mel-spectrogram statistics successfully.') +else: + print('Stat file not found.') + +root = src +fargs_list = [ + [ + f, + f.replace(src, dst), + lambda x: melspec_transform(x, melspec_scaler), + ] + for f in walk_files(root, ext) +] +print(fargs_list) +results = joblib.Parallel(n_jobs=1)( + joblib.delayed(normalize_features)(*f) for f in tqdm(fargs_list) +) + + +def walk_files(root, extension): + for path, dirs, files in os.walk(root): + for file in files: + if file.endswith(extension): + yield os.path.join(path, file) + +def read_melspec(filepath): + with h5py.File(filepath, "r") as f: + melspec = f["melspec"][()] # n_mels x n_frame + # import pdb;pdb.set_trace() # Breakpoint + return melspec + + +def compute_statistics(src, stat_filepath): + melspec_scaler = StandardScaler() + + filepath_list = list(walk_files(src, '.h5')) + for filepath in tqdm(filepath_list): + melspec = read_melspec(filepath) + # import pdb;pdb.set_trace() # Breakpoint + melspec_scaler.partial_fit(melspec.T) + + with open(stat_filepath, mode='wb') as f: + pickle.dump(melspec_scaler, f) + + +fmt = '%(asctime)s 
(%(module)s:%(lineno)d) %(levelname)s: %(message)s'
+datafmt = '%m/%d/%Y %I:%M:%S'
+logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datafmt)
+
+src = './models/stargan/dump/arctic/feat/train'
+stat_filepath = './models/stargan/dump/arctic/stat.pkl'
+if not os.path.exists(os.path.dirname(stat_filepath)):
+    os.makedirs(os.path.dirname(stat_filepath))
+
+compute_statistics(src, stat_filepath)
+
+# Training hyperparameters and paths
+gpu = -1
+data_rootdir = './models/stargan/dump/arctic/norm_feat/train'
+epochs = 2000
+snapshot = 200
+batch_size = 12
+num_mels = 80
+arch_type = 'conv'
+loss_type = 'wgan'
+zdim = 16
+hdim = 64
+mdim = 32
+sdim = 16
+lrate_g = 0.0005
+lrate_d = 5e-6
+gradient_clip = 1.0
+w_adv = 1.0
+w_grad = 1.0
+w_cls = 1.0
+w_cyc = 1.0
+w_rec = 1.0
+normtype = 'IN'
+src_conditioning = 0
+resume = 0
+model_rootdir = './model/arctic/'
+log_dir = './logs/arctic/'
+experiment_name = 'experiment1'
+
+
+
+
+def train_stargan():
+    # Set up GPU
+    if torch.cuda.is_available() and gpu >= 0:
+        device = torch.device('cuda:%d' % gpu)
+    else:
+        device = torch.device('cpu')
+    if device.type == 'cuda':
+        torch.cuda.set_device(device)
+
+
+    spk_list = sorted(os.listdir(data_rootdir))
+    n_spk = len(spk_list)
+    melspec_dirs = [os.path.join(data_rootdir, spk) for spk in spk_list]
+
+    config = {
+        'num_mels': num_mels,
+        'arch_type': arch_type,
+        'loss_type': loss_type,
+        'zdim': zdim,
+        'hdim': hdim,
+        'mdim': mdim,
+        'sdim': sdim,
+        'w_adv': 1.0,
+        'w_grad': 1.0,
+        'w_cls': 1.0,
+        'w_cyc': 1.0,
+        'w_rec': 1.0,
+        'lrate_g': lrate_g,
+        'lrate_d': lrate_d,
+        'gradient_clip': 1.0,
+        'normtype': normtype,
+        'epochs': epochs,
+        'BatchSize': batch_size,
+        'n_spk': n_spk,
+        'spk_list': spk_list,
+        'src_conditioning': src_conditioning
+    }
+
+    model_dir = os.path.join(model_rootdir, experiment_name)
+    makedirs_if_not_exists(model_dir)
+    log_path = os.path.join(log_dir, experiment_name, 'train_{}.log'.format(experiment_name))
+
+    # Save configuration as a json file
+    config_path = os.path.join(model_dir, 'model_config.json')
+    with open(config_path, 'w') as outfile:
+        json.dump(config, outfile, indent=4)
+
+    # Build the generator (1D-conv or bidirectional-LSTM variant) and discriminator
+    if arch_type == 'conv':
+        gen = net.Generator1(num_mels, n_spk, zdim, hdim, sdim, normtype, src_conditioning)
+    elif arch_type == 'rnn':
+        gen = net.Generator2(num_mels, n_spk, zdim, hdim, sdim, src_conditioning=src_conditioning)
+    dis = net.Discriminator1(num_mels, n_spk, mdim, normtype)
+    models = {
+        'gen': gen,
+        'dis': dis
+    }
+    models['stargan'] = net.StarGAN(models['gen'], models['dis'], n_spk, loss_type)
+
+    optimizers = {
+        'gen': optim.Adam(models['gen'].parameters(), lr=lrate_g, betas=(0.9, 0.999)),
+        'dis': optim.Adam(models['dis'].parameters(), lr=lrate_d, betas=(0.5, 0.999))
+    }
+
+    for tag in ['gen', 'dis']:
+        models[tag].to(device).train(mode=True)
+
+    train_dataset = MultiDomain_Dataset(*melspec_dirs)
+    train_loader = DataLoader(train_dataset,
+                              batch_size=batch_size,
+                              shuffle=True,
+                              num_workers=0,
+                              # num_workers=os.cpu_count(),
+                              drop_last=True,
+                              collate_fn=collate_fn)
+
+    fmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
+    datafmt = '%m/%d/%Y %I:%M:%S'
+    if not os.path.exists(os.path.dirname(log_path)):
+        os.makedirs(os.path.dirname(log_path))
+    logging.basicConfig(filename=log_path, filemode='a', level=logging.INFO, format=fmt, datefmt=datafmt)
+    writer = SummaryWriter(os.path.dirname(log_path))
+
+    if not os.path.exists(model_dir):
+        os.makedirs(model_dir)
+
+    for tag in ['gen', 'dis']:
+        checkpointpath = os.path.join(model_dir, '{}.{}.pt'.format(resume, tag))
+        if
os.path.exists(checkpointpath): + checkpoint = torch.load(checkpointpath, map_location=device) + models[tag].load_state_dict(checkpoint['model_state_dict']) + optimizers[tag].load_state_dict(checkpoint['optimizer_state_dict']) + print('{} loaded successfully.'.format(checkpointpath)) + + w_adv = config['w_adv'] + w_grad = config['w_grad'] + w_cls = config['w_cls'] + w_cyc = config['w_cyc'] + w_rec = config['w_rec'] + gradient_clip = config['gradient_clip'] + + print("===================================Training Started===================================") + n_iter = 0 + for epoch in range(resume + 1, epochs + 1): + b = 0 + for X_list in train_loader: + n_spk = len(X_list) + xin = [] + for s in range(n_spk): + xin.append(torch.tensor(X_list[s]).to(device, dtype=torch.float)) + + # List of speaker pairs + spk_pair_list = comb(n_spk, 2) + n_spk_pair = len(spk_pair_list) + + gen_loss_mean = 0 + dis_loss_mean = 0 + advloss_d_mean = 0 + gradloss_d_mean = 0 + advloss_g_mean = 0 + clsloss_d_mean = 0 + clsloss_g_mean = 0 + cycloss_mean = 0 + recloss_mean = 0 + # Iterate through all speaker pairs + for m in range(n_spk_pair): + s0 = spk_pair_list[m][0] + s1 = spk_pair_list[m][1] + + AdvLoss_g, ClsLoss_g, CycLoss, RecLoss = models['stargan'].calc_gen_loss(xin[s0], xin[s1], s0, s1) + gen_loss = (w_adv * AdvLoss_g + w_cls * ClsLoss_g + w_cyc * CycLoss + w_rec * RecLoss) + + models['gen'].zero_grad() + gen_loss.backward() + torch.nn.utils.clip_grad_norm_(models['gen'].parameters(), gradient_clip) + optimizers['gen'].step() + + AdvLoss_d, GradLoss_d, ClsLoss_d = models['stargan'].calc_dis_loss(xin[s0], xin[s1], s0, s1) + dis_loss = w_adv * AdvLoss_d + w_grad * GradLoss_d + w_cls * ClsLoss_d + + models['dis'].zero_grad() + dis_loss.backward() + torch.nn.utils.clip_grad_norm_(models['dis'].parameters(), gradient_clip) + optimizers['dis'].step() + + gen_loss_mean += gen_loss.item() + dis_loss_mean += dis_loss.item() + advloss_d_mean += AdvLoss_d.item() + gradloss_d_mean += GradLoss_d.item() + advloss_g_mean += AdvLoss_g.item() + clsloss_d_mean += ClsLoss_d.item() + clsloss_g_mean += ClsLoss_g.item() + cycloss_mean += CycLoss.item() + recloss_mean += RecLoss.item() + + gen_loss_mean /= n_spk_pair + dis_loss_mean /= n_spk_pair + advloss_d_mean /= n_spk_pair + gradloss_d_mean /= n_spk_pair + advloss_g_mean /= n_spk_pair + clsloss_d_mean /= n_spk_pair + clsloss_g_mean /= n_spk_pair + cycloss_mean /= n_spk_pair + recloss_mean /= n_spk_pair + + logging.info( + 'epoch {}, mini-batch {}: AdvLoss_d={:.4f}, AdvLoss_g={:.4f}, GradLoss_d={:.4f}, ClsLoss_d={:.4f}, ClsLoss_g={:.4f}' + .format(epoch, b + 1, w_adv * advloss_d_mean, w_adv * advloss_g_mean, w_grad * gradloss_d_mean, + w_cls * clsloss_d_mean, w_cls * clsloss_g_mean)) + logging.info( + 'epoch {}, mini-batch {}: CycLoss={:.4f}, RecLoss={:.4f}'.format(epoch, b + 1, w_cyc * cycloss_mean, + w_rec * recloss_mean)) + writer.add_scalars('Loss/Total_Loss', {'adv_loss_d': w_adv * advloss_d_mean, + 'adv_loss_g': w_adv * advloss_g_mean, + 'grad_loss_d': w_grad * gradloss_d_mean, + 'cls_loss_d': w_cls * clsloss_d_mean, + 'cls_loss_g': w_cls * clsloss_g_mean, + 'cyc_loss': w_cyc * cycloss_mean, + 'rec_loss': w_rec * recloss_mean}, n_iter) + n_iter += 1 + b += 1 + + if epoch % snapshot == 0: + for tag in ['gen', 'dis']: + print('save {} at {} epoch'.format(tag, epoch)) + torch.save({'epoch': epoch, + 'model_state_dict': models[tag].state_dict(), + 'optimizer_state_dict': optimizers[tag].state_dict()}, + os.path.join(model_dir, '{}.{}.pt'.format(epoch, tag))) + + 
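+    # Descriptive note: snapshots are written every `snapshot` (here 200) epochs
+    # as '<epoch>.gen.pt' and '<epoch>.dis.pt'; setting `resume` to one of those
+    # epoch numbers lets the checkpoint-loading loop above continue training
+    # from that snapshot.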
print("===================================Training Finished===================================") + + +train_stargan() +# Train(models, epochs, train_dataset, train_loader, optimizers, device, model_dir, log_path, model_config, snapshot, +# resume) \ No newline at end of file diff --git a/models/stargan/useless.txt b/models/stargan/useless.txt new file mode 100644 index 00000000..e69de29b diff --git a/module.py b/module.py new file mode 100644 index 00000000..f88b3bee --- /dev/null +++ b/module.py @@ -0,0 +1,136 @@ +# Copyright 2021 Hirokazu Kameoka + +import torch +import torch.nn as nn + +def calc_padding(kernel_size, dilation, causal, stride=1): + if causal: + padding = (kernel_size-1)*dilation+1-stride + else: + padding = ((kernel_size-1)*dilation+1-stride)//2 + return padding + +class LinearWN(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True): + super(LinearWN, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + #nn.init.xavier_normal_(self.linear_layer.weight,gain=0.1) + self.linear_layer = nn.utils.weight_norm(self.linear_layer) + + def forward(self, x): + return self.linear_layer(x) + +class ConvGLU1D(nn.Module): + def __init__(self, in_ch, out_ch, ks, sd, normtype='IN'): + super(ConvGLU1D, self).__init__() + self.conv1 = nn.Conv1d( + in_ch, out_ch*2, ks, stride=sd, padding=(ks-sd)//2) + nn.init.xavier_normal_(self.conv1.weight,gain=0.1) + if normtype=='BN': + self.norm1 = nn.BatchNorm1d(out_ch*2) + elif normtype=='IN': + self.norm1 = nn.InstanceNorm1d(out_ch*2) + elif normtype=='LN': + self.norm1 = nn.LayerNorm(out_ch*2) + + self.conv1 = nn.utils.weight_norm(self.conv1) + self.normtype = normtype + + def __call__(self, x): + h = self.conv1(x) + if self.normtype=='BN' or self.normtype=='IN': + h = self.norm1(h) + elif self.normtype=='LN': + B, D, N = h.shape + h = h.permute(0,2,1).reshape(-1,D) + h = self.norm1(h) + h = h.reshape(B,N,D).permute(0,2,1) + h_l, h_g = torch.split(h, h.shape[1]//2, dim=1) + h = h_l * torch.sigmoid(h_g) + + return h + +class DeconvGLU1D(nn.Module): + def __init__(self, in_ch, out_ch, ks, sd, normtype='IN'): + super(DeconvGLU1D, self).__init__() + self.conv1 = nn.ConvTranspose1d( + in_ch, out_ch*2, ks, stride=sd, padding=(ks-sd)//2) + nn.init.xavier_normal_(self.conv1.weight,gain=0.1) + if normtype=='BN': + self.norm1 = nn.BatchNorm1d(out_ch*2) + elif normtype=='IN': + self.norm1 = nn.InstanceNorm1d(out_ch*2) + elif normtype=='LN': + self.norm1 = nn.LayerNorm(out_ch*2) + + self.conv1 = nn.utils.weight_norm(self.conv1) + self.normtype = normtype + + def __call__(self, x): + h = self.conv1(x) + if self.normtype=='BN' or self.normtype=='IN': + h = self.norm1(h) + elif self.normtype=='LN': + B, D, N = h.shape + h = h.permute(0,2,1).reshape(-1,D) + h = self.norm1(h) + h = h.reshape(B,N,D).permute(0,2,1) + h_l, h_g = torch.split(h, h.shape[1]//2, dim=1) + h = h_l * torch.sigmoid(h_g) + + return h + +class PixelShuffleGLU1D(nn.Module): + def __init__(self, in_ch, out_ch, ks, sd, normtype='IN'): + super(PixelShuffleGLU1D, self).__init__() + self.conv1 = nn.Conv1d( + in_ch, out_ch*2*sd, ks, stride=1, padding=(ks-1)//2) + self.r = sd + if normtype=='BN': + self.norm1 = nn.BatchNorm1d(out_ch*2) + elif normtype=='IN': + self.norm1 = nn.InstanceNorm1d(out_ch*2) + self.conv1 = nn.utils.weight_norm(self.conv1) + self.normtype = normtype + + def __call__(self, x): + h = self.conv1(x) + N, pre_ch, pre_len = h.shape + r = self.r + post_ch = pre_ch//r + post_len = pre_len * r + h = torch.reshape(h, (N, r, post_ch, pre_len)) + 
h = h.permute(0,2,3,1) + + h = torch.reshape(h, (N, post_ch, post_len)) + if self.normtype=='BN' or self.normtype=='IN': + h = self.norm1(h) + h_l, h_g = torch.split(h, h.shape[1]//2, dim=1) + h = h_l * torch.sigmoid(h_g) + + return h + +def concat_dim1(x,y): + assert x.shape[0] == y.shape[0] + if torch.Tensor.dim(x) == 3: + y0 = torch.unsqueeze(y,2) + N, n_ch, n_t = x.shape + yy = y0.repeat(1,1,n_t) + h = torch.cat((x,yy), dim=1) + elif torch.Tensor.dim(x) == 4: + y0 = torch.unsqueeze(torch.unsqueeze(y,2),3) + N, n_ch, n_q, n_t = x.shape + yy = y0.repeat(1,1,n_q,n_t) + h = torch.cat((x,yy), dim=1) + return h + +def concat_dim2(x,y): + assert x.shape[0] == y.shape[0] + if torch.Tensor.dim(x) == 3: + y0 = torch.unsqueeze(y,1) + N, n_t, n_ch = x.shape + yy = y0.repeat(1,n_t,1) + h = torch.cat((x,yy), dim=2) + elif torch.Tensor.dim(x) == 2: + h = torch.cat((x,y), dim=1) + return h \ No newline at end of file diff --git a/net.py b/net.py new file mode 100644 index 00000000..60531dea --- /dev/null +++ b/net.py @@ -0,0 +1,396 @@ +import numpy as np +import six +import torch +import torch.nn as nn +import torch.nn.functional as F +import math +import module as md + +class Generator1(nn.Module): + # 1D convolutional architecture + def __init__(self, in_ch, n_spk, z_ch, mid_ch, s_ch, normtype='IN', src_conditioning=False): + super(Generator1, self).__init__() + add_ch = 0 if src_conditioning==False else s_ch + self.le1 = md.ConvGLU1D(in_ch+add_ch, mid_ch, 9, 1, normtype) + self.le2 = md.ConvGLU1D(mid_ch+add_ch, mid_ch, 8, 2, normtype) + self.le3 = md.ConvGLU1D(mid_ch+add_ch, mid_ch, 8, 2, normtype) + self.le4 = md.ConvGLU1D(mid_ch+add_ch, mid_ch, 5, 1, normtype) + self.le5 = md.ConvGLU1D(mid_ch+add_ch, z_ch, 5, 1, normtype) + self.le6 = md.DeconvGLU1D(z_ch+s_ch, mid_ch, 5, 1, normtype) + self.le7 = md.DeconvGLU1D(mid_ch+s_ch, mid_ch, 5, 1, normtype) + self.le8 = md.DeconvGLU1D(mid_ch+s_ch, mid_ch, 8, 2, normtype) + self.le9 = md.DeconvGLU1D(mid_ch+s_ch, mid_ch, 8, 2, normtype) + self.le10 = nn.Conv1d(mid_ch+s_ch, in_ch, 9, stride=1, padding=(9-1)//2) + #nn.init.xavier_normal_(self.le10.weight,gain=0.1) + + if src_conditioning: + self.eb0 = nn.Embedding(n_spk, s_ch) + self.eb1 = nn.Embedding(n_spk, s_ch) + self.src_conditioning = src_conditioning + + def __call__(self, xin, k_t, k_s=None): + device = xin.device + B, n_mels, n_frame_ = xin.shape + + kk_t = k_t*torch.ones(B).to(device, dtype=torch.int64) + trgspk_emb = self.eb1(kk_t) + if self.src_conditioning: + kk_s = k_s*torch.ones(B).to(device, dtype=torch.int64) + srcspk_emb = self.eb0(kk_s) + + out = xin + + if self.src_conditioning: out = md.concat_dim1(out,srcspk_emb) + out = self.le1(out) + if self.src_conditioning: + out = md.concat_dim1(out,srcspk_emb) + out = self.le2(out) + if self.src_conditioning: + out = md.concat_dim1(out,srcspk_emb) + out = self.le3(out) + if self.src_conditioning: + out = md.concat_dim1(out,srcspk_emb) + out = self.le4(out) + if self.src_conditioning: + out = md.concat_dim1(out,srcspk_emb) + out = self.le5(out) + out = md.concat_dim1(out,trgspk_emb) + out = self.le6(out) + out = md.concat_dim1(out,trgspk_emb) + out = self.le7(out) + out = md.concat_dim1(out,trgspk_emb) + out = self.le8(out) + out = md.concat_dim1(out,trgspk_emb) + out = self.le9(out) + out = md.concat_dim1(out,trgspk_emb) + out = self.le10(out) + + return out + +class Generator2(nn.Module): + # Bidirectional LSTM + def __init__(self, in_ch, n_spk, z_ch, mid_ch, s_ch, num_layers=2, negative_slope=0.1, src_conditioning=False): + super(Generator2, 
self).__init__() + add_ch = 0 if src_conditioning==False else s_ch + + self.linear0 = md.LinearWN(in_ch+add_ch, mid_ch) + self.lrelu0 = nn.LeakyReLU(negative_slope) + self.rnn0 = nn.LSTM( + mid_ch+add_ch, + mid_ch//2, + num_layers, + dropout=0, + bidirectional=True, + batch_first = True + ) + self.linear1 = md.LinearWN(mid_ch+add_ch, z_ch) + self.linear2 = md.LinearWN(z_ch+s_ch, mid_ch) + self.lrelu1 = nn.LeakyReLU(negative_slope) + self.rnn1 = nn.LSTM( + mid_ch+s_ch, + mid_ch//2, + num_layers, + dropout=0, + bidirectional=True, + batch_first = True + ) + self.linear3 = md.LinearWN(mid_ch+s_ch, in_ch) + + if src_conditioning: + self.eb0 = nn.Embedding(n_spk, s_ch) + self.eb1 = nn.Embedding(n_spk, s_ch) + self.src_conditioning = src_conditioning + + def __call__(self, xin, k_t, k_s=None): + device = xin.device + B, num_mels, num_frame = xin.shape + kk_t = k_t*torch.ones(B).to(device, dtype=torch.int64) + trgspk_emb = self.eb1(kk_t) + if self.src_conditioning: + kk_s = k_s*torch.ones(B).to(device, dtype=torch.int64) + srcspk_emb = self.eb0(kk_s) + out = xin + + out = out.permute(0,2,1) # (B, num_frame, num_mels) + if self.src_conditioning: out = md.concat_dim2(out,srcspk_emb) # (B, num_frame, num_mels+add_ch) + out = self.lrelu0(self.linear0(out)) + if self.src_conditioning: out = md.concat_dim2(out,srcspk_emb) # (B, num_frame, mid_ch+add_ch) + self.rnn0.flatten_parameters() + out, _ = self.rnn0(out) # (B, num_frame, mid_ch) + if self.src_conditioning: out = md.concat_dim2(out,srcspk_emb) # (B, num_frame, mid_ch+add_ch) + out = self.linear1(out) # (B, num_frame, z_ch) + out = md.concat_dim2(out,trgspk_emb) # (B, num_frame, z_ch+s_ch) + out = self.lrelu1(self.linear2(out)) # (B, num_frame, mid_ch) + out = md.concat_dim2(out,trgspk_emb) # (B, num_frame, mid_ch+s_ch) + self.rnn1.flatten_parameters() + out, _ = self.rnn1(out) # (B, num_frame, mid_ch) + out = md.concat_dim2(out,trgspk_emb) # (B, num_frame, mid_ch+s_ch) + out = self.linear3(out) # (B, num_frame, in_ch) + out = out.permute(0,2,1) # (B, in_ch, num_frame) + + return out + +class Discriminator1(nn.Module): + # 1D convolutional architecture + def __init__(self, in_ch, clsnum, mid_ch, normtype='IN', dor=0.1): + super(Discriminator1, self).__init__() + self.le1 = md.ConvGLU1D(in_ch, mid_ch, 9, 1, normtype) + self.le2 = md.ConvGLU1D(mid_ch, mid_ch, 8, 2, normtype) + self.le3 = md.ConvGLU1D(mid_ch, mid_ch, 8, 2, normtype) + self.le4 = md.ConvGLU1D(mid_ch, mid_ch, 5, 1, normtype) + self.le_adv = nn.Conv1d(mid_ch, 1, 5, stride=1, padding=(5-1)//2, bias=False) + self.le_cls = nn.Conv1d(mid_ch, clsnum, 5, stride=1, padding=(5-1)//2, bias=False) + nn.init.xavier_normal_(self.le_adv.weight,gain=0.1) + nn.init.xavier_normal_(self.le_cls.weight,gain=0.1) + self.do1 = nn.Dropout(p=dor) + self.do2 = nn.Dropout(p=dor) + self.do3 = nn.Dropout(p=dor) + self.do4 = nn.Dropout(p=dor) + + def __call__(self, xin): + device = xin.device + B, n_mels, n_frame_ = xin.shape + + out = xin + + out = self.do1(self.le1(out)) + out = self.do2(self.le2(out)) + out = self.do3(self.le3(out)) + out = self.do4(self.le4(out)) + + out_adv = self.le_adv(out) + out_cls = self.le_cls(out) + + return out_adv, out_cls + +class StarGAN(nn.Module): + def __init__(self, gen, dis, n_spk, loss_type='wgan'): + super(StarGAN, self).__init__() + self.gen = gen + self.dis = dis + self.n_spk = n_spk + self.loss_type = loss_type + + def forward(self, x, k_t, k_s=None): + device = x.device + n_frame_ = x.shape[2] + n_frame = math.ceil(n_frame_/4)*4 + if n_frame > n_frame_: + x = 
+
+    def calc_advloss_g(self, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts):
+        df_adv_ss = df_adv_ss.permute(0,2,1).reshape(-1,1)
+        df_adv_st = df_adv_st.permute(0,2,1).reshape(-1,1)
+        df_adv_tt = df_adv_tt.permute(0,2,1).reshape(-1,1)
+        df_adv_ts = df_adv_ts.permute(0,2,1).reshape(-1,1)
+
+        if self.loss_type=='wgan':
+            # Wasserstein GAN with gradient penalty (WGAN-GP)
+            AdvLoss_g = (
+                torch.sum(-df_adv_ss)
+                + torch.sum(-df_adv_st)
+                + torch.sum(-df_adv_tt)
+                + torch.sum(-df_adv_ts)
+            ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+
+        elif self.loss_type=='lsgan':
+            # Least squares GAN (LSGAN)
+            AdvLoss_g = 0.5 * (
+                torch.sum((df_adv_ss - torch.ones_like(df_adv_ss))**2)
+                + torch.sum((df_adv_st - torch.ones_like(df_adv_st))**2)
+                + torch.sum((df_adv_tt - torch.ones_like(df_adv_tt))**2)
+                + torch.sum((df_adv_ts - torch.ones_like(df_adv_ts))**2)
+            ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+
+        elif self.loss_type=='cgan':
+            # Regular GAN with the sigmoid cross-entropy criterion (CGAN)
+            AdvLoss_g = (
+                F.binary_cross_entropy_with_logits(df_adv_ss, torch.ones_like(df_adv_ss), reduction='sum')
+                + F.binary_cross_entropy_with_logits(df_adv_st, torch.ones_like(df_adv_st), reduction='sum')
+                + F.binary_cross_entropy_with_logits(df_adv_tt, torch.ones_like(df_adv_tt), reduction='sum')
+                + F.binary_cross_entropy_with_logits(df_adv_ts, torch.ones_like(df_adv_ts), reduction='sum')
+            ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+
+        return AdvLoss_g
+
+    def calc_clsloss_g(self, df_cls_ss, df_cls_st, df_cls_tt, df_cls_ts, k_s, k_t):
+        device = df_cls_ss.device
+
+        df_cls_ss = df_cls_ss.permute(0,2,1).reshape(-1,self.n_spk)
+        df_cls_st = df_cls_st.permute(0,2,1).reshape(-1,self.n_spk)
+        df_cls_tt = df_cls_tt.permute(0,2,1).reshape(-1,self.n_spk)
+        df_cls_ts = df_cls_ts.permute(0,2,1).reshape(-1,self.n_spk)
+
+        cf_ss = k_s*torch.ones(len(df_cls_ss), device=device, dtype=torch.long)
+        cf_st = k_t*torch.ones(len(df_cls_st), device=device, dtype=torch.long)
+        cf_tt = k_t*torch.ones(len(df_cls_tt), device=device, dtype=torch.long)
+        cf_ts = k_s*torch.ones(len(df_cls_ts), device=device, dtype=torch.long)
+
+        ClsLoss_g = (
+            F.cross_entropy(df_cls_ss, cf_ss, reduction='sum')
+            + F.cross_entropy(df_cls_st, cf_st, reduction='sum')
+            + F.cross_entropy(df_cls_tt, cf_tt, reduction='sum')
+            + F.cross_entropy(df_cls_ts, cf_ts, reduction='sum')
+        ) / (df_cls_ss.numel() + df_cls_st.numel() + df_cls_tt.numel() + df_cls_ts.numel())
+
+        return ClsLoss_g
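+
+    # Label convention, by example: with k_s=0 and k_t=2, xf_st = gen(x_s, k_t, k_s)
+    # is a speaker-0 utterance converted toward speaker 2, so its classifier
+    # target cf_st is 2; the identity conversions xf_ss/xf_tt keep their own
+    # labels. Note the 'sum' cross-entropies are divided by .numel() of the
+    # (frames x n_spk) logit matrices, i.e. a per-element rather than
+    # per-frame average.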
+
+    def calc_advloss_d(self, x_s, x_t, xf_ts, xf_st, dr_adv_s, dr_adv_t, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts):
+        device = x_s.device
+        B_s = len(x_s)
+        B_t = len(x_t)
+
+        dr_adv_s = dr_adv_s.permute(0,2,1).reshape(-1,1)
+        dr_adv_t = dr_adv_t.permute(0,2,1).reshape(-1,1)
+        df_adv_ss = df_adv_ss.permute(0,2,1).reshape(-1,1)
+        df_adv_st = df_adv_st.permute(0,2,1).reshape(-1,1)
+        df_adv_tt = df_adv_tt.permute(0,2,1).reshape(-1,1)
+        df_adv_ts = df_adv_ts.permute(0,2,1).reshape(-1,1)
+
+        if self.loss_type=='wgan':
+            # Wasserstein GAN with gradient penalty (WGAN-GP)
+            AdvLoss_d_r = (
+                torch.sum(-dr_adv_s)
+                + torch.sum(-dr_adv_t)
+            ) / (dr_adv_s.numel() + dr_adv_t.numel())
+            AdvLoss_d_f = (
+                torch.sum(df_adv_ss)
+                + torch.sum(df_adv_st)
+                + torch.sum(df_adv_tt)
+                + torch.sum(df_adv_ts)
+            ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+            AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f
+
+        elif self.loss_type=='lsgan':
+            # Least squares GAN (LSGAN)
+            AdvLoss_d_r = 0.5 * (
+                torch.sum((dr_adv_s - torch.ones_like(dr_adv_s))**2)
+                + torch.sum((dr_adv_t - torch.ones_like(dr_adv_t))**2)
+            ) / (dr_adv_s.numel() + dr_adv_t.numel())
+            AdvLoss_d_f = 0.5 * (
+                torch.sum(df_adv_ss**2)
+                + torch.sum(df_adv_st**2)
+                + torch.sum(df_adv_tt**2)
+                + torch.sum(df_adv_ts**2)
+            ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+            AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f
+
+        elif self.loss_type=='cgan':
+            # Regular GAN with sigmoid cross-entropy criterion (CGAN)
+            AdvLoss_d_r = (
+                F.binary_cross_entropy_with_logits(dr_adv_s, torch.ones_like(dr_adv_s), reduction='sum')
+                + F.binary_cross_entropy_with_logits(dr_adv_t, torch.ones_like(dr_adv_t), reduction='sum')
+            ) / (dr_adv_s.numel() + dr_adv_t.numel())
+            AdvLoss_d_f = (
+                F.binary_cross_entropy_with_logits(df_adv_ss, torch.zeros_like(df_adv_ss), reduction='sum')
+                + F.binary_cross_entropy_with_logits(df_adv_st, torch.zeros_like(df_adv_st), reduction='sum')
+                + F.binary_cross_entropy_with_logits(df_adv_tt, torch.zeros_like(df_adv_tt), reduction='sum')
+                + F.binary_cross_entropy_with_logits(df_adv_ts, torch.zeros_like(df_adv_ts), reduction='sum')
+            ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+            AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f
+
+        # Gradient penalty loss on random interpolates between real and converted samples
+        alpha_t = torch.rand(B_t, 1, 1, requires_grad=True).to(device)
+        interpolates = alpha_t * x_t + ((1 - alpha_t) * xf_ts)
+        interpolates = interpolates.to(device)
+        disc_interpolates, _ = self.dis(interpolates)
+        disc_interpolates = torch.sum(disc_interpolates)
+        gradients = torch.autograd.grad(outputs=disc_interpolates,
+                                        inputs=interpolates,
+                                        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                                        create_graph=True, retain_graph=True, only_inputs=True)[0]
+        gradnorm = torch.sqrt(torch.sum(gradients * gradients, (1, 2)))
+        loss_gp_t = ((gradnorm - 1)**2).mean()
+
+        alpha_s = torch.rand(B_s, 1, 1, requires_grad=True).to(device)
+        interpolates = alpha_s * x_s + ((1 - alpha_s) * xf_st)
+        interpolates = interpolates.to(device)
+        disc_interpolates, _ = self.dis(interpolates)
+        disc_interpolates = torch.sum(disc_interpolates)
+        gradients = torch.autograd.grad(outputs=disc_interpolates,
+                                        inputs=interpolates,
+                                        grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+                                        create_graph=True, retain_graph=True, only_inputs=True)[0]
+        gradnorm = torch.sqrt(torch.sum(gradients * gradients, (1, 2)))
+        loss_gp_s = ((gradnorm - 1)**2).mean()
+
+        GradLoss_d = loss_gp_s + loss_gp_t
+
+        return AdvLoss_d, GradLoss_d
+
+    def calc_clsloss_d(self, dr_cls_s, dr_cls_t, k_s, k_t):
+        device = dr_cls_s.device
+
+        dr_cls_s = dr_cls_s.permute(0,2,1).reshape(-1,self.n_spk)
+        dr_cls_t = dr_cls_t.permute(0,2,1).reshape(-1,self.n_spk)
+
+        cr_s = k_s*torch.ones(len(dr_cls_s), device=device, dtype=torch.long)
+        cr_t = k_t*torch.ones(len(dr_cls_t), device=device, dtype=torch.long)
+
+        ClsLoss_d = (
+            F.cross_entropy(dr_cls_s, cr_s, reduction='sum')
+            + F.cross_entropy(dr_cls_t, cr_t, reduction='sum')
+        ) / (dr_cls_s.numel() + dr_cls_t.numel())
+
+        return ClsLoss_d
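+
+    # Usage sketch (the lambda_* weights and wiring are illustrative
+    # assumptions, not values taken from this repository's trainer):
+    #   model = StarGAN(Generator1(80, n_spk, 16, 64, 16), Discriminator1(80, n_spk, 64), n_spk, 'wgan')
+    #   adv_g, cls_g, cyc, rec = model.calc_gen_loss(x_s, x_t, k_s, k_t)
+    #   loss_g = adv_g + lambda_cls*cls_g + lambda_cyc*cyc + lambda_rec*rec
+    #   adv_d, gp, cls_d = model.calc_dis_loss(x_s, x_t, k_s, k_t)
+    #   loss_d = adv_d + lambda_gp*gp + lambda_cls*cls_d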
+
+    def calc_gen_loss(self, x_s, x_t, k_s, k_t):
+        # Generator outputs
+        xf_ss = self.gen(x_s, k_s, k_s)
+        xf_ts = self.gen(x_t, k_s, k_t)
+        xf_tt = self.gen(x_t, k_t, k_t)
+        xf_st = self.gen(x_s, k_t, k_s)
+
+        # Discriminator outputs
+        df_adv_ss, df_cls_ss = self.dis(xf_ss)
+        df_adv_st, df_cls_st = self.dis(xf_st)
+        df_adv_tt, df_cls_tt = self.dis(xf_tt)
+        df_adv_ts, df_cls_ts = self.dis(xf_ts)
+
+        # Adversarial loss
+        AdvLoss_g = self.calc_advloss_g(df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts)
+
+        # Classifier loss
+        ClsLoss_g = self.calc_clsloss_g(df_cls_ss, df_cls_st, df_cls_tt, df_cls_ts, k_s, k_t)
+
+        # Cycle-consistency loss
+        CycLoss = (
+            torch.sum(torch.abs(x_s - self.gen(xf_st, k_s, k_t)))
+            + torch.sum(torch.abs(x_t - self.gen(xf_ts, k_t, k_s)))
+        ) / (x_s.numel() + x_t.numel())
+
+        # Reconstruction loss
+        RecLoss = (
+            torch.sum(torch.abs(x_s - xf_ss))
+            + torch.sum(torch.abs(x_t - xf_tt))
+        ) / (x_s.numel() + x_t.numel())
+
+        return AdvLoss_g, ClsLoss_g, CycLoss, RecLoss
+
+    def calc_dis_loss(self, x_s, x_t, k_s, k_t):
+        device = x_s.device
+
+        # Generator outputs
+        xf_ss = self.gen(x_s, k_s, k_s)
+        xf_ts = self.gen(x_t, k_s, k_t)
+        xf_tt = self.gen(x_t, k_t, k_t)
+        xf_st = self.gen(x_s, k_t, k_s)
+
+        # Discriminator outputs
+        dr_adv_s, dr_cls_s = self.dis(x_s)
+        dr_adv_t, dr_cls_t = self.dis(x_t)
+        df_adv_ss, _ = self.dis(xf_ss)
+        df_adv_st, _ = self.dis(xf_st)
+        df_adv_tt, _ = self.dis(xf_tt)
+        df_adv_ts, _ = self.dis(xf_ts)
+
+        # Adversarial loss
+        AdvLoss_d, GradLoss_d = self.calc_advloss_d(x_s, x_t, xf_ts, xf_st, dr_adv_s, dr_adv_t, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts)
+
+        # Classifier loss
+        ClsLoss_d = self.calc_clsloss_d(dr_cls_s, dr_cls_t, k_s, k_t)
+
+        return AdvLoss_d, GradLoss_d, ClsLoss_d
diff --git a/run_talkingface.py b/run_talkingface.py
index 3989d566..a257a762 100644
--- a/run_talkingface.py
+++ b/run_talkingface.py
@@ -1,11 +1,19 @@
 import argparse
-from talkingface.quick_start import run
+# from talkingface.quick_start import run
+from models.stargan import train_stargan_model
+
+
+def run(model_name, dataset_name, config_file_list=None, evaluate_model_file=None, train=False):
+    if train:
+        if model_name == 'stargan':
+            train_stargan_model()
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model", "-m", type=str, default="BPR", help="name of models")
+    parser.add_argument("--model", "-m", type=str, default="stargan", help="name of models")
     parser.add_argument(
-        "--dataset", "-d", type=str, default=None, help="name of datasets"
+        "--dataset", "-d", type=str, default='vctk', help="name of datasets"
     )
     parser.add_argument("--evaluate_model_file", type=str, default=None, help="The model file you want to evaluate")
     parser.add_argument("--config_files", type=str, default=None, help="config files")
diff --git a/talkingface/data/Arctic.yaml b/talkingface/data/Arctic.yaml
new file mode 100644
index 00000000..72a17db6
--- /dev/null
+++ b/talkingface/data/Arctic.yaml
@@ -0,0 +1,12 @@
+flen: 1024
+fmax: 7600
+fmin: 80
+fs: 16000
+fshift: 128
+num_mels: 80
+top_db: 30
+trim_silence: false
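+# Note: flen/fshift are the STFT frame length and hop size in samples,
+# i.e. 64 ms and 8 ms at fs=16000; fmin/fmax (Hz) bound the mel filterbank.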
+stat_path: "./dataset/vctk/filelist/stat.pkl"
+data_path: "./dataset/vctk/data"
+model: "stargan_vc"
+train_filelist: "./dataset/vctk/filelist/train.txt"
\ No newline at end of file
diff --git a/talkingface/data/dataprocess/__init__.py b/talkingface/data/dataprocess/__init__.py
index 7c7aef87..9a724efa 100644
--- a/talkingface/data/dataprocess/__init__.py
+++ b/talkingface/data/dataprocess/__init__.py
@@ -1 +1,2 @@
-from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio, Wav2LipPreprocessForInference
\ No newline at end of file
+from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio, Wav2LipPreprocessForInference
+from talkingface.data.dataprocess.stargan_vc_process import StarganAudio
\ No newline at end of file
diff --git a/talkingface/data/dataprocess/stargan_vc_process.py b/talkingface/data/dataprocess/stargan_vc_process.py
new file mode 100644
index 00000000..63c106a7
--- /dev/null
+++ b/talkingface/data/dataprocess/stargan_vc_process.py
@@ -0,0 +1,84 @@
+
+import warnings
+import librosa
+import numpy as np
+import soundfile as sf
+
+
+class StarganAudio:
+    def __init__(self, config):
+        # Store the configuration parameters.
+        self.kwargs = config
+
+    def logmelfilterbank(self, audio,
+                         sampling_rate,
+                         fft_size=1024,
+                         hop_size=256,
+                         win_length=None,
+                         window="hann",
+                         num_mels=80,
+                         fmin=None,
+                         fmax=None,
+                         eps=1e-10,
+                         ):
+        """Compute log-Mel filterbank feature.
+        Args:
+            audio (ndarray): Audio signal (T,).
+            sampling_rate (int): Sampling rate.
+            fft_size (int): FFT size.
+            hop_size (int): Hop size.
+            win_length (int): Window length. If set to None, it will be the same as fft_size.
+            window (str): Window function type.
+            num_mels (int): Number of mel basis.
+            fmin (int): Minimum frequency in mel basis calculation.
+            fmax (int): Maximum frequency in mel basis calculation.
+            eps (float): Epsilon value to avoid inf in log calculation.
+        Returns:
+            ndarray: Log Mel filterbank feature (#frames, num_mels).
+        """
+        # Compute the magnitude spectrogram.
+        x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
+                              win_length=win_length, window=window, pad_mode="reflect")
+        spc = np.abs(x_stft).T  # (#frames, #bins)
+
+        # Build the mel filterbank (librosa >= 0.10 requires keyword arguments).
+        fmin = 0 if fmin is None else fmin
+        fmax = sampling_rate / 2 if fmax is None else fmax
+        mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size,
+                                        n_mels=num_mels, fmin=fmin, fmax=fmax)
+
+        return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
+
+    def extract_melspec(self, src_filepath):
+        try:
+            warnings.filterwarnings('ignore')
+
+            # Pull the feature-extraction parameters out of the config.
+            trim_silence = self.kwargs['trim_silence']
+            top_db = self.kwargs['top_db']
+            flen = self.kwargs['flen']
+            fshift = self.kwargs['fshift']
+            fmin = self.kwargs['fmin']
+            fmax = self.kwargs['fmax']
+            num_mels = self.kwargs['num_mels']
+            fs = self.kwargs['fs']
+
+            # Read the audio file.
+            audio, fs_ = sf.read(src_filepath)
+            if trim_silence:
+                # Trim leading/trailing silence if requested.
+                audio, _ = librosa.effects.trim(audio, top_db=top_db, frame_length=2048, hop_length=512)
+            if fs != fs_:
+                # Resample to the target rate if necessary
+                # (librosa >= 0.10 requires keyword arguments).
+                audio = librosa.resample(audio, orig_sr=fs_, target_sr=fs)
+            # Extract the log-mel spectrogram.
+            melspec_raw = self.logmelfilterbank(audio, fs, fft_size=flen, hop_size=fshift,
+                                                fmin=fmin, fmax=fmax, num_mels=num_mels)
+            melspec_raw = melspec_raw.astype(np.float32)
+            melspec_raw = melspec_raw.T  # n_mels x n_frame
+            return melspec_raw
+
+        except Exception:
+            print(f"{src_filepath}...failed.")
+            return None
+
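+# Usage sketch (the config values mirror talkingface/data/Arctic.yaml; the
+# filename is illustrative):
+#   processor = StarganAudio({'trim_silence': False, 'top_db': 30,
+#                             'flen': 1024, 'fshift': 128, 'fmin': 80,
+#                             'fmax': 7600, 'num_mels': 80, 'fs': 16000})
+#   mel = processor.extract_melspec('p225_001.wav')  # (num_mels, n_frames) float32, or None on failure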
diff --git a/talkingface/data/dataset/__init__.py b/talkingface/data/dataset/__init__.py
index 3fd37538..d482fb31 100644
--- a/talkingface/data/dataset/__init__.py
+++ b/talkingface/data/dataset/__init__.py
@@ -1,2 +1,3 @@
 from talkingface.data.dataset.wav2lip_dataset import Wav2LipDataset
-from talkingface.data.dataset.dataset import Dataset
\ No newline at end of file
+from talkingface.data.dataset.dataset import Dataset
+from talkingface.data.dataset.stargan_vc_dataset import StarganDataset
\ No newline at end of file
diff --git a/talkingface/data/dataset/compute_statistics_dataset.py b/talkingface/data/dataset/compute_statistics_dataset.py
new file mode 100644
index 00000000..a690c314
--- /dev/null
+++ b/talkingface/data/dataset/compute_statistics_dataset.py
@@ -0,0 +1,30 @@
+import os
+from tqdm import tqdm
+from sklearn.preprocessing import StandardScaler
+import pickle
+
+def walk_files(root, extension):
+    for path, dirs, files in os.walk(root):
+        for file in files:
+            if file.endswith(extension):
+                yield os.path.join(path, file)
+
+def compute_statistics(src, processor, stat_filepath="stat.pkl"):
+    # Fit a StandardScaler over the mel spectrograms of every file under src
+    # (one subdirectory per speaker) and pickle it for later normalization.
+    melspec_scaler = StandardScaler()
+    filenames_all = [[os.path.join(src, d, t) for t in sorted(os.listdir(os.path.join(src, d)))] for d in os.listdir(src)]
+    for filepath_list in tqdm(filenames_all):
+        for filepath in filepath_list:
+            melspec = processor.extract_melspec(filepath)
+            if melspec is not None:
+                melspec_scaler.partial_fit(melspec.T)
+
+    with open(stat_filepath, mode='wb') as f:
+        pickle.dump(melspec_scaler, f)
+    print("Saved.")
+
+
+if __name__ == '__main__':
+    pass
\ No newline at end of file
diff --git a/talkingface/data/dataset/stargan_vc_dataset.py b/talkingface/data/dataset/stargan_vc_dataset.py
new file mode 100644
index 00000000..8d3d8561
--- /dev/null
+++ b/talkingface/data/dataset/stargan_vc_dataset.py
@@ -0,0 +1,91 @@
+# Import the required modules and functions.
+
+import os
+from talkingface.data.dataprocess.stargan_vc_process import StarganAudio
+from talkingface.data.dataset.dataset import Dataset
+import math
+import numpy as np
+from sklearn.preprocessing import StandardScaler
+import pickle
+
+
+class StarganDataset(Dataset):
+    def __init__(self, config, file_list):
+        self.config = config
+        # Get the preprocessed-data root from the config and list its
+        # subdirectories (one per speaker).
+        feat_dirs = os.listdir(self.config['preprocessed_root'])
+
+        with open(file_list) as f:
+            text = f.readlines()
+        file_list_ = []
+        for line in text:
+            line = line.strip()
+            file_list_.append(line)
+        # Build one filename list per speaker directory under the root,
+        # keeping only the files named in file_list_.
+        self.filenames_all = [
+            [
+                os.path.join(self.config['preprocessed_root'], d, t) for t in sorted(os.listdir(os.path.join(self.config['preprocessed_root'], d))) if t in file_list_
+            ] for d in feat_dirs
+        ]
+        self.n_domain = len(self.filenames_all)
+        self.feat_dirs = feat_dirs
+        # Create a StarganAudio object to handle the audio processing.
+        self.process_audio = StarganAudio(config)
+
+        self.melspec_scaler = StandardScaler()
+        if os.path.exists(config['stat_path']):
+            with open(config['stat_path'], mode='rb') as f:
+                self.melspec_scaler = pickle.load(f)
+        else:
+            print("Melspec_scaler is None.")
+            self.melspec_scaler = None
+
+    def __len__(self):
+        return min(len(f) for f in self.filenames_all)
+
+    def __getitem__(self, idx):
+        melspec_list = []
+        # Iterate over each speaker's parallel wav files.
+        for d in range(self.n_domain):
+            # Process the audio file and get its mel spectrogram.
+            melspec = self.process_audio.extract_melspec(self.filenames_all[d][idx])  # n_freq x n_time
+            if self.melspec_scaler is not None:
+                melspec = self.melspec_scaler.transform(melspec.T).T
+            # Append the mel spectrogram (n_freq x n_time) to the list.
+            melspec_list.append(melspec)
+        # Return the list containing every domain's mel spectrogram.
+        return melspec_list
+        # return {"melspec_list": melspec_list}
+
+    def collate_fn(self, batch):
+        # batch[b][s]: melspec (n_freq x n_frame)
+        # b: batch index
+        # s: speaker ID
+
+        batchsize = len(batch)
+        n_spk = len(batch[0])
+        melspec_list = [[batch[b][s] for b in range(batchsize)] for s in range(n_spk)]
+        # melspec_list[s][b]: melspec (n_freq x n_frame)
+        # s: speaker ID
+        # b: batch index
+
+        n_freq = melspec_list[0][0].shape[0]
+
+        X_list = []
+        for s in range(n_spk):
+            maxlen = 0
+            for b in range(batchsize):
+                if maxlen