diff --git a/README.md b/README.md
index 7cd82328..4738901a 100644
--- a/README.md
+++ b/README.md
@@ -1,210 +1,44 @@
-# talkingface-toolkit
-## Framework Overview
-### checkpoints
-Mainly stores the extra pre-trained models needed for training and evaluating models; see the [README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/checkpoints/README.md) in the corresponding folder for more details
+# talkingface-toolkit Group Assignment
+Group members: 高艺芙, 贺芳琪, 陈清扬 (names are ordered by the logical sequence of each member's work, not by workload)
+# Model Selection
+We chose to reproduce [StarGAN-VC](https://github.com/kamepong/StarGAN-VC), a model designed specifically for voice conversion. It extends the StarGAN framework, which was originally developed for image-to-image translation, and focuses on converting the vocal characteristics of one speaker into those of another.
+# Environment
+Python 3.8
-### dataset
-Stores the datasets and the data produced by preprocessing them; see the [README](https://github.com/Academic-Hammer/talkingface-toolkit/blob/main/dataset/README.md) in dataset for details
+IDE: PyCharm 2022.1.3
-### saved
-Stores the model checkpoints produced during training; created automatically when a model is saved
+Framework: PyTorch 2.1
-### talkingface
-The main functional module, containing all of the core code
+Operating systems: Windows 11, macOS
-#### config
-Automatically generates all model, dataset, training, and evaluation configuration from the model and dataset names
-```
-config/
+Audio processing library: librosa 0.10.1
-├── configurator.py
+Numerical computing library: NumPy 1.20.3
+# Dataset
+Contains VCTK speakers p225, p226, and p227, split into test, train, and val subsets; see the dataset folder for details.
+# Run Commands
+Main command: `python run_talkingface.py --model stargan --dataset vctk`
-```
-#### data
-- dataprocess: model-specific data-processing code (e.g., a model repository's own audio feature extraction or inference-time processing). Create a corresponding file here if the implemented model needs it
-- dataset: every model must subclass `torch.utils.data.Dataset` to load its data, in a file named `model_name+'_dataset.py'`. The return value of `__getitem__()` should be a dict. (core part)
-```
-data/
+Data preparation command: `./recipes/run_train.sh`
-├── dataprocess
+Loading data: for example, training uses the corresponding split by constructing `StarganDataset(config, config['train_filelist'])`, which loads the train set; see the sketch below.
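+
+A minimal usage sketch (assuming `config` has already been assembled by the toolkit's config module; the exact import path is hypothetical):
+```python
+from talkingface.data.dataset.stargan_dataset import StarganDataset  # hypothetical import path
+
+train_data = StarganDataset(config, config['train_filelist'])  # load the train split
+val_data = StarganDataset(config, config['val_filelist'])      # the val split, analogously
+```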
+# Implemented Features
+Overall: running `python run_talkingface.py --model stargan --dataset vctk` performs data processing and model training directly.
-| ├── wav2lip_process.py
+Data preparation: obtained the StarGAN model's configuration-file parameters and applied them in the toolkit.
-| ├── xxxx_process.py
-
-├── dataset
-
-| ├── wav2lip_dataset.py
-
-| ├── xxx_dataset.py
-```
-
-#### evaluate
-Code related to model evaluation
-The LSE metric requires a list of generated videos
-The SSIM metric requires lists of generated and ground-truth videos
-
-#### model
-The networks of the implemented models and their corresponding methods (core part)
-
-Three main categories:
-- audio-driven
-- image-driven
-- nerf-based (methods based on neural radiance fields)
-
-```
-model/
-
-├── audio_driven_talkingface
-
-| ├── wav2lip.py
-
-├── image_driven_talkingface
-
-| ├── xxxx.py
-
-├── nerf_based_talkingface
-
-| ├── xxxx.py
-
-├── abstract_talkingface.py
-
-```
-
-#### properties
-Stores the default configuration files, including:
-- dataset configuration files
-- model configuration files
-- the general configuration file
-
-Add configuration files for each new model and dataset as needed; the general configuration file `overall.yaml` usually stays unchanged
-```
-properties/
-
-├── dataset
-
-| ├── xxx.yaml
-
-├── model
-
-| ├── xxx.yaml
-
-├── overall.yaml
-
-```
-
-#### quick_start
-The general entry point: it configures the dataset and model automatically from the passed arguments, then trains and evaluates (usually needs no modification)
-```
-quick_start/
-
-├── quick_start.py
-
-```
-
-#### trainer
-The main class for the training and evaluation routines. If the base `Trainer` class can provide all needed functionality, there is no need to write a new one. If training has model-specific parts, override `Trainer`; the parts to override are usually `_train_epoch()` and `_valid_epoch()`. An overridden `Trainer` should be named `{model_name}Trainer`
-```
-trainer/
-
-├── trainer.py
-
-```
-
-#### utils
-Shared utilities, including `s3fd` face detection, video frame extraction, and audio extraction from video, plus methods that locate the model and data classes from the configuration.
-Usually needs no modification, though necessary and relatively general data-processing files may be added.
-
-## Usage
-### Requirements
-- `python=3.8`
-- `torch==1.13.1+cu116` (GPU build; the CPU build can be used if the device does not support CUDA)
-- `numpy==1.20.3`
-- `librosa==0.10.1`
-
-Try to keep the versions of these packages exactly as listed
-
-Two ways to set up the rest of the environment are provided:
-```
-pip install -r requirements.txt
-
-or
-
-conda env create -f environment.yml
-```
-
-A conda virtual environment is strongly recommended!
-
-### Training and Evaluation
-
-```bash
-python run_talkingface.py --model=xxxx --dataset=xxxx (--other_parameters=xxxxxx)
-```
-
-### Weight Files
-
-- weights needed for LSE evaluation: syncnet_v2.model [Baidu Netdisk download](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc)
-- lip expert weights needed by wav2lip: lipsync_expert.pth [Baidu Netdisk download](https://pan.baidu.com/s/1vQoL9FuKlPyrHOGKihtfVA?pwd=32hc)
-
-## Candidate Papers
-### Audio_driven talkingface
-| Model | Paper | Code |
-|:--------:|:--------:|:--------:|
-| MakeItTalk | [paper](https://arxiv.org/abs/2004.12992) | [code](https://github.com/yzhou359/MakeItTalk) |
-| MEAD | [paper](https://wywu.github.io/projects/MEAD/support/MEAD.pdf) | [code](https://github.com/uniBruce/Mead) |
-| RhythmicHead | [paper](https://arxiv.org/pdf/2007.08547v1.pdf) | [code](https://github.com/lelechen63/Talking-head-Generation-with-Rhythmic-Head-Motion) |
-| PC-AVS | [paper](https://arxiv.org/abs/2104.11116) | [code](https://github.com/Hangz-nju-cuhk/Talking-Face_PC-AVS) |
-| EVP | [paper](https://openaccess.thecvf.com/content/CVPR2021/papers/Ji_Audio-Driven_Emotional_Video_Portraits_CVPR_2021_paper.pdf) | [code](https://github.com/jixinya/EVP) |
-| LSP | [paper](https://arxiv.org/abs/2109.10595) | [code](https://github.com/YuanxunLu/LiveSpeechPortraits) |
-| EAMM | [paper](https://arxiv.org/pdf/2205.15278.pdf) | [code](https://github.com/jixinya/EAMM/) |
-| DiffTalk | [paper](https://arxiv.org/abs/2301.03786) | [code](https://github.com/sstzal/DiffTalk) |
-| TalkLip | [paper](https://arxiv.org/pdf/2303.17480.pdf) | [code](https://github.com/Sxjdwang/TalkLip) |
-| EmoGen | [paper](https://arxiv.org/pdf/2303.11548.pdf) | [code](https://github.com/sahilg06/EmoGen) |
-| SadTalker | [paper](https://arxiv.org/abs/2211.12194) | [code](https://github.com/OpenTalker/SadTalker) |
-| HyperLips | [paper](https://arxiv.org/abs/2310.05720) | [code](https://github.com/semchan/HyperLips) |
-| PHADTF | [paper](http://arxiv.org/abs/2002.10137) | [code](https://github.com/yiranran/Audio-driven-TalkingFace-HeadPose) |
-| VideoReTalking | [paper](https://arxiv.org/abs/2211.14758) | [code](https://github.com/OpenTalker/video-retalking#videoretalking--audio-based-lip-synchronization-for-talking-head-video-editing-in-the-wild-) |
-| |
-
-
-
-### Image_driven talkingface
-| Model | Paper | Code |
-|:--------:|:--------:|:--------:|
-| PIRenderer | [paper](https://arxiv.org/pdf/2109.08379.pdf) | [code](https://github.com/RenYurui/PIRender) |
-| StyleHEAT | [paper](https://arxiv.org/pdf/2203.04036.pdf) | [code](https://github.com/OpenTalker/StyleHEAT) |
-| MetaPortrait | [paper](https://arxiv.org/abs/2212.08062) | [code](https://github.com/Meta-Portrait/MetaPortrait) |
-| |
-### Nerf-based talkingface
-| Model | Paper | Code |
-|:--------:|:--------:|:--------:|
-| AD-NeRF | [paper](https://arxiv.org/abs/2103.11078) | [code](https://github.com/YudongGuo/AD-NeRF) |
-| GeneFace | [paper](https://arxiv.org/abs/2301.13430) | [code](https://github.com/yerfor/GeneFace) |
-| DFRF | [paper](https://arxiv.org/abs/2207.11770) | [code](https://github.com/sstzal/DFRF) |
-| |
-### text_to_speech
-| Model | Paper | Code |
-|:--------:|:--------:|:--------:|
-| VITS | [paper](https://arxiv.org/abs/2106.06103) | [code](https://github.com/jaywalnut310/vits) |
-| Glow TTS | [paper](https://arxiv.org/abs/2005.11129) | [code](https://github.com/jaywalnut310/glow-tts) |
-| FastSpeech2 | [paper](https://arxiv.org/abs/2006.04558v1) | [code](https://github.com/ming024/FastSpeech2) |
-| StyleTTS2 | [paper](https://arxiv.org/abs/2306.07691) | [code](https://github.com/yl4579/StyleTTS2) |
-| Grad-TTS | [paper](https://arxiv.org/abs/2105.06337) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS) |
-| FastSpeech | [paper](https://arxiv.org/abs/1905.09263) | [code](https://github.com/xcmyz/FastSpeech) |
-| |
-### voice_conversion
-| Model | Paper | Code |
-|:--------:|:--------:|:--------:|
-| StarGAN-VC | [paper](http://www.kecl.ntt.co.jp/people/kameoka.hirokazu/Demos/stargan-vc2/index.html) | [code](https://github.com/kamepong/StarGAN-VC) |
-| Emo-StarGAN | [paper](https://www.researchgate.net/publication/373161292_Emo-StarGAN_A_Semi-Supervised_Any-to-Many_Non-Parallel_Emotion-Preserving_Voice_Conversion) | [code](https://github.com/suhitaghosh10/emo-stargan) |
-| adaptive-VC | [paper](https://arxiv.org/abs/1904.05742) | [code](https://github.com/jjery2243542/adaptive_voice_conversion) |
-| DiffVC | [paper](https://arxiv.org/abs/2109.13821) | [code](https://github.com/huawei-noah/Speech-Backbones/tree/main/DiffVC) |
-| Assem-VC | [paper](https://arxiv.org/abs/2104.00931) | [code](https://github.com/maum-ai/assem-vc) |
-| |
-
-## Assignment Requirements
-- Make sure training and validation can be run by entering only the model and dataset names on the command line. (For repositories that provide no training code, training may be skipped)
-- Every group must submit a README stating the completed features, the final training and validation screenshots, the dependencies used, the division of work among members, etc.
+Data: preprocessed the data following the StarGAN source code, extracting mel-spectrograms and converting them into a usable form; split the data into three subsets at an 80%/10%/10% ratio, generating train.txt, val.txt, and test.txt files that are saved in the dataset folder; a sketch of the split follows.
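+
+For illustration, a minimal sketch of such an 80%/10%/10% filelist split (the wav location and output paths are assumptions based on the description above; the actual split code may differ):
+```python
+import os
+import random
+
+wavs = sorted(os.listdir('./dataset/vctk/data'))  # assumed raw-wav location
+random.seed(0)                                    # fixed seed for a reproducible split
+random.shuffle(wavs)
+n = len(wavs)
+splits = {'train.txt': wavs[:int(0.8 * n)],
+          'val.txt': wavs[int(0.8 * n):int(0.9 * n)],
+          'test.txt': wavs[int(0.9 * n):]}
+for name, files in splits.items():
+    with open(os.path.join('./dataset', name), 'w') as f:
+        f.write('\n'.join(files))
+```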
+# Result Screenshots
+
+
+
+
+# Member Contributions (written in task order)
+高艺芙-1120213132-07022102: set up the experimental environment, using `pip install -r requirements.txt` to update packages and resolve errors, and ran training to obtain the program's results. Prepared the experimental data and ran the StarGAN model to obtain the configuration files Arctic.json and StarGAN.json, converted them to Arctic.yaml and StarGAN.yaml, integrated them into the corresponding file structure under talkingface-toolkit/talkingface/properties, and did the final print debugging.
+贺芳琪-1120210640-08012101: made some configuration adjustments and was mainly responsible for the data and preprocessing part, including the dataset split. Integrated the preprocessing code from StarGAN-VC's dataset.py, compute_statistics.py, extract_features.py, and normalize_features.py into the dataset and dataprocess folders under talkingface-toolkit/talkingface/data, and modified the preprocessing parameters in the toolkit's yaml files. A detailed explanation can be found in the readme of the data section.
+
+陈清扬-1120213599-07112106: refactored the model code, the training code, and the inference files; adjusted the configuration files; tested the code and fixed bugs; wrote the experiment report.
+
diff --git a/convert.py b/convert.py
new file mode 100644
index 00000000..3fefa46a
--- /dev/null
+++ b/convert.py
@@ -0,0 +1,208 @@
+# Copyright 2021 Hirokazu Kameoka
+
+import os
+import argparse
+import torch
+import json
+import numpy as np
+import re
+import pickle
+from tqdm import tqdm
+import yaml
+
+import librosa
+import soundfile as sf
+from sklearn.preprocessing import StandardScaler
+
+import net
+from extract_features import logmelfilterbank
+
+import sys
+sys.path.append(os.path.abspath("pwg"))
+from pwg.parallel_wavegan.utils import load_model
+from pwg.parallel_wavegan.utils import read_hdf5
+
+def audio_transform(wav_filepath, scaler, kwargs, device):
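+    # Load a wav file, optionally trim silence and resample it, extract a log-mel
+    # spectrogram, and normalize it with the scaler fitted on the training data.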
+
+ trim_silence = kwargs['trim_silence']
+ top_db = kwargs['top_db']
+ flen = kwargs['flen']
+ fshift = kwargs['fshift']
+ fmin = kwargs['fmin']
+ fmax = kwargs['fmax']
+ num_mels = kwargs['num_mels']
+ fs = kwargs['fs']
+
+ audio, fs_ = sf.read(wav_filepath)
+ if trim_silence:
+ #print('trimming.')
+ audio, _ = librosa.effects.trim(audio, top_db=top_db, frame_length=2048, hop_length=512)
+ if fs != fs_:
+ #print('resampling.')
+        audio = librosa.resample(y=audio, orig_sr=fs_, target_sr=fs)
+ melspec_raw = logmelfilterbank(audio,fs, fft_size=flen,hop_size=fshift,
+ fmin=fmin, fmax=fmax, num_mels=num_mels)
+ melspec_raw = melspec_raw.astype(np.float32) # n_frame x n_mels
+
+ melspec_norm = scaler.transform(melspec_raw)
+ melspec_norm = melspec_norm.T # n_mels x n_frame
+
+ return torch.tensor(melspec_norm[None]).to(device, dtype=torch.float)
+
+def extract_num(s, p, ret=0):
+ search = p.search(s)
+ if search:
+ return int(search.groups()[0])
+ else:
+ return ret
+
+def listdir_ext(dirpath,ext):
+ p = re.compile(r'(\d+)')
+ out = []
+ for file in sorted(os.listdir(dirpath), key=lambda s: extract_num(s, p)):
+ if os.path.splitext(file)[1]==ext:
+ out.append(file)
+ return out
+
+def find_newest_model_file(model_dir, tag):
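+    # Checkpoints are named '<epoch>.<tag>.pt'; return the largest epoch index present.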
+ mfile_list = os.listdir(model_dir)
+ checkpoint = max([int(os.path.splitext(os.path.splitext(mfile)[0])[0]) for mfile in mfile_list if mfile.endswith('.{}.pt'.format(tag))])
+ return checkpoint
+
+
+def synthesis(melspec, model_nv, nv_config, savepath, device):
+ ## Parallel WaveGAN / MelGAN
+ melspec = torch.tensor(melspec, dtype=torch.float).to(device)
+ #start = time.time()
+ x = model_nv.inference(melspec).view(-1)
+ #elapsed_time = time.time() - start
+ #rtf2 = elapsed_time/audio_len
+ #print ("elapsed_time (waveform generation): {0}".format(elapsed_time) + "[sec]")
+ #print ("real time factor (waveform generation): {0}".format(rtf2))
+
+ # save as PCM 16 bit wav file
+ if not os.path.exists(os.path.dirname(savepath)):
+ os.makedirs(os.path.dirname(savepath))
+ sf.write(savepath, x.detach().cpu().clone().numpy(), nv_config["sampling_rate"], "PCM_16")
+
+def main():
+ parser = argparse.ArgumentParser(description='Testing StarGAN-VC')
+ parser.add_argument('--gpu', '-g', type=int, default=-1, help='GPU ID (negative value indicates CPU)')
+ parser.add_argument('-i', '--input', type=str, default='/misc/raid58/kameoka.hirokazu/python/db/arctic/wav/test',
+ help='root data folder that contains the wav files of input speech')
+ parser.add_argument('-o', '--out', type=str, default='./out/arctic',
+ help='root data folder where the wav files of the converted speech will be saved.')
+ parser.add_argument('--dataconf', type=str, default='./dump/arctic/data_config.json')
+ parser.add_argument('--stat', type=str, default='./dump/arctic/stat.pkl', help='state file used for normalization')
+ parser.add_argument('--model_rootdir', '-mdir', type=str, default='./model/arctic/', help='model file directory')
+ parser.add_argument('--checkpoint', '-ckpt', type=int, default=0, help='model checkpoint to load (0 indicates the newest model)')
+ parser.add_argument('--experiment_name', '-exp', default='experiment1', type=str, help='experiment name')
+ parser.add_argument('--vocoder', '-voc', default='hifigan.v1', type=str,
+ help='neural vocoder type name (e.g., hifigan.v1, hifigan.v2, parallel_wavegan.v1)')
+ parser.add_argument('--voc_dir', '-vdir', type=str, default='pwg/egs/arctic_4spk_flen64ms_fshift8ms/voc1',
+ help='directory of trained neural vocoder')
+ args = parser.parse_args()
+
+ # Set up GPU
+ if torch.cuda.is_available() and args.gpu >= 0:
+ device = torch.device('cuda:%d' % args.gpu)
+ else:
+ device = torch.device('cpu')
+ if device.type == 'cuda':
+ torch.cuda.set_device(device)
+
+ input_dir = args.input
+ data_config_path = args.dataconf
+ model_config_path = os.path.join(args.model_rootdir,args.experiment_name,'model_config.json')
+ with open(data_config_path) as f:
+ data_config = json.load(f)
+ with open(model_config_path) as f:
+ model_config = json.load(f)
+ checkpoint = args.checkpoint
+
+ num_mels = model_config['num_mels']
+ arch_type = model_config['arch_type']
+ loss_type = model_config['loss_type']
+ n_spk = model_config['n_spk']
+ trg_spk_list = model_config['spk_list']
+ zdim = model_config['zdim']
+ hdim = model_config['hdim']
+ mdim = model_config['mdim']
+ sdim = model_config['sdim']
+ normtype = model_config['normtype']
+ src_conditioning = model_config['src_conditioning']
+
+ stat_filepath = args.stat
+ melspec_scaler = StandardScaler()
+ if os.path.exists(stat_filepath):
+ with open(stat_filepath, mode='rb') as f:
+ melspec_scaler = pickle.load(f)
+ print('Loaded mel-spectrogram statistics successfully.')
+ else:
+ print('Stat file not found.')
+
+ # Set up main model
+    # 'conv' selects the 1D-convolutional generator; otherwise the bidirectional-LSTM generator is used.
+    gen = net.Generator1(num_mels, n_spk, zdim, hdim, sdim, normtype, src_conditioning) if arch_type=='conv' else net.Generator2(num_mels, n_spk, zdim, hdim, sdim, src_conditioning=src_conditioning)
+    dis = net.Discriminator1(num_mels, n_spk, mdim, normtype)
+ models = {
+ 'gen': gen,
+ 'dis': dis
+ }
+ models['stargan'] = net.StarGAN(models['gen'],models['dis'],n_spk,loss_type)
+
+ for tag in ['gen', 'dis']:
+ model_dir = os.path.join(args.model_rootdir,args.experiment_name)
+ vc_checkpoint_idx = find_newest_model_file(model_dir, tag) if checkpoint <= 0 else checkpoint
+ mfilename = '{}.{}.pt'.format(vc_checkpoint_idx,tag)
+ path = os.path.join(args.model_rootdir,args.experiment_name,mfilename)
+ if path is not None:
+ model_checkpoint = torch.load(path, map_location=device)
+ models[tag].load_state_dict(model_checkpoint['model_state_dict'])
+ print('{}: {}'.format(tag, os.path.abspath(path)))
+
+ for tag in ['gen', 'dis']:
+ #models[tag].to(device).eval()
+ models[tag].to(device).train(mode=True)
+
+ # Set up nv
+ vocoder = args.vocoder
+ voc_dir = args.voc_dir
+ voc_yaml_path = os.path.join(voc_dir,'conf', '{}.yaml'.format(vocoder))
+ checkpointlist = listdir_ext(
+ os.path.join(voc_dir,'exp','train_nodev_all_{}'.format(vocoder)),'.pkl')
+ nv_checkpoint = os.path.join(voc_dir,'exp',
+ 'train_nodev_all_{}'.format(vocoder),
+ checkpointlist[-1]) # Find and use the newest checkpoint model.
+ print('vocoder: {}'.format(os.path.abspath(nv_checkpoint)))
+
+ with open(voc_yaml_path) as f:
+ nv_config = yaml.load(f, Loader=yaml.Loader)
+ nv_config.update(vars(args))
+ model_nv = load_model(nv_checkpoint, nv_config)
+ model_nv.remove_weight_norm()
+ model_nv = model_nv.eval().to(device)
+
+ src_spk_list = sorted(os.listdir(input_dir))
+
+ for i, src_spk in enumerate(src_spk_list):
+ src_wav_dir = os.path.join(input_dir, src_spk)
+ for j, trg_spk in enumerate(trg_spk_list):
+ if src_spk != trg_spk:
+ print('Converting {}2{}...'.format(src_spk, trg_spk))
+ for n, src_wav_filename in enumerate(os.listdir(src_wav_dir)):
+ src_wav_filepath = os.path.join(src_wav_dir, src_wav_filename)
+ src_melspec = audio_transform(src_wav_filepath, melspec_scaler, data_config, device)
+ k_t = j
+ k_s = i if src_conditioning else None
+
+ conv_melspec = models['stargan'](src_melspec, k_t, k_s)
+
+ conv_melspec = conv_melspec[0,:,:].detach().cpu().clone().numpy()
+ conv_melspec = conv_melspec.T # n_frames x n_mels
+
+ out_wavpath = os.path.join(args.out,args.experiment_name,'{}'.format(vc_checkpoint_idx),vocoder,'{}2{}'.format(src_spk,trg_spk), src_wav_filename)
+ synthesis(conv_melspec, model_nv, nv_config, out_wavpath, device)
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 00000000..471c89d5
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,65 @@
+# Copyright 2021 Hirokazu Kameoka
+
+import os
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+import h5py
+import math
+import random
+
+def walk_files(root, extension):
+ for path, dirs, files in os.walk(root):
+ for file in files:
+ if file.endswith(extension):
+ yield os.path.join(path, file)
+
+class MultiDomain_Dataset(Dataset):
+ def __init__(self, *feat_dirs):
+ self.n_domain = len(feat_dirs)
+ self.filenames_all = [[os.path.join(d,t) for t in sorted(os.listdir(d))] for d in feat_dirs]
+ #self.filenames_all = [[t for t in walk_files(d, '.h5')] for d in feat_dirs]
+ self.feat_dirs = feat_dirs
+
+ def __len__(self):
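+        # Cap the epoch length at the smallest domain so every index is valid in all domains.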
+ return min(len(f) for f in self.filenames_all)
+
+ def __getitem__(self, idx):
+ melspec_list = []
+ for d in range(self.n_domain):
+ with h5py.File(self.filenames_all[d][idx], "r") as f:
+ melspec = f["melspec"][()] # n_freq x n_time
+ melspec_list.append(melspec)
+ return melspec_list
+
+def collate_fn(batch):
+ #batch[b][s]: melspec (n_freq x n_frame)
+ #b: batch size
+ #s: speaker ID
+
+ batchsize = len(batch)
+ n_spk = len(batch[0])
+ melspec_list = [[batch[b][s] for b in range(batchsize)] for s in range(n_spk)]
+ #melspec_list[s][b]: melspec (n_freq x n_frame)
+ #s: speaker ID
+ #b: batch size
+
+ n_freq = melspec_list[0][0].shape[0]
+
+ X_list = []
+ for s in range(n_spk):
+ maxlen=0
+ for b in range(batchsize):
+                if maxlen < melspec_list[s][b].shape[1]:
+                    maxlen = melspec_list[s][b].shape[1]
+        # Zero-pad every spectrogram in the batch to the longest one for this speaker.
+        X = np.zeros((batchsize, n_freq, maxlen), dtype=np.float32)
+        for b in range(batchsize):
+            melspec = melspec_list[s][b]
+            X[b, :, :melspec.shape[1]] = melspec
+        X_list.append(X)
+    return X_list
diff --git a/models/stargan/dataset.py b/models/stargan/dataset.py
new file mode 100644
index 00000000..471c89d5
--- /dev/null
+++ b/models/stargan/dataset.py
@@ -0,0 +1,65 @@
+# Copyright 2021 Hirokazu Kameoka
+
+import os
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+import h5py
+import math
+import random
+
+def walk_files(root, extension):
+ for path, dirs, files in os.walk(root):
+ for file in files:
+ if file.endswith(extension):
+ yield os.path.join(path, file)
+
+class MultiDomain_Dataset(Dataset):
+ def __init__(self, *feat_dirs):
+ self.n_domain = len(feat_dirs)
+ self.filenames_all = [[os.path.join(d,t) for t in sorted(os.listdir(d))] for d in feat_dirs]
+ #self.filenames_all = [[t for t in walk_files(d, '.h5')] for d in feat_dirs]
+ self.feat_dirs = feat_dirs
+
+ def __len__(self):
+ return min(len(f) for f in self.filenames_all)
+
+ def __getitem__(self, idx):
+ melspec_list = []
+ for d in range(self.n_domain):
+ with h5py.File(self.filenames_all[d][idx], "r") as f:
+ melspec = f["melspec"][()] # n_freq x n_time
+ melspec_list.append(melspec)
+ return melspec_list
+
+def collate_fn(batch):
+ #batch[b][s]: melspec (n_freq x n_frame)
+ #b: batch size
+ #s: speaker ID
+
+ batchsize = len(batch)
+ n_spk = len(batch[0])
+ melspec_list = [[batch[b][s] for b in range(batchsize)] for s in range(n_spk)]
+ #melspec_list[s][b]: melspec (n_freq x n_frame)
+ #s: speaker ID
+ #b: batch size
+
+ n_freq = melspec_list[0][0].shape[0]
+
+ X_list = []
+ for s in range(n_spk):
+ maxlen=0
+ for b in range(batchsize):
+                if maxlen < melspec_list[s][b].shape[1]:
+                    maxlen = melspec_list[s][b].shape[1]
+        # Zero-pad every spectrogram in the batch to the longest one for this speaker.
+        X = np.zeros((batchsize, n_freq, maxlen), dtype=np.float32)
+        for b in range(batchsize):
+            melspec = melspec_list[s][b]
+            X[b, :, :melspec.shape[1]] = melspec
+        X_list.append(X)
+    return X_list
diff --git a/models/stargan/net.py b/models/stargan/net.py
new file mode 100644
index 00000000..60531dea
--- /dev/null
+++ b/models/stargan/net.py
@@ -0,0 +1,396 @@
+    def forward(self, x, k_t, k_s=None):
+        device = x.device
+        n_frame_ = x.shape[2]
+        n_frame = math.ceil(n_frame_/4)*4
+        if n_frame > n_frame_:
+ x = nn.ReplicationPad1d((0, n_frame-n_frame_))(x)
+ return self.gen(x, k_t, k_s)[:,:,0:n_frame_]
+
+ def calc_advloss_g(self, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts):
+ df_adv_ss = df_adv_ss.permute(0,2,1).reshape(-1,1)
+ df_adv_st = df_adv_st.permute(0,2,1).reshape(-1,1)
+ df_adv_tt = df_adv_tt.permute(0,2,1).reshape(-1,1)
+ df_adv_ts = df_adv_ts.permute(0,2,1).reshape(-1,1)
+
+ if self.loss_type=='wgan':
+ # Wasserstein GAN with gradient penalty (WGAN-GP)
+ AdvLoss_g = (
+ torch.sum(-df_adv_ss) +
+ torch.sum(-df_adv_st) +
+ torch.sum(-df_adv_tt) +
+ torch.sum(-df_adv_ts)
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+
+ elif self.loss_type=='lsgan':
+ # Least squares GAN (LSGAN)
+ AdvLoss_g = 0.5 * (
+ torch.sum((df_adv_ss - torch.ones_like(df_adv_ss))**2) +
+ torch.sum((df_adv_st - torch.ones_like(df_adv_st))**2) +
+ torch.sum((df_adv_tt - torch.ones_like(df_adv_tt))**2) +
+ torch.sum((df_adv_ts - torch.ones_like(df_adv_ts))**2)
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+
+ elif self.loss_type=='cgan':
+ # Regular GAN with the sigmoid cross-entropy criterion (CGAN)
+ AdvLoss_g = (
+ F.binary_cross_entropy_with_logits(df_adv_ss, torch.ones_like(df_adv_ss), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_st, torch.ones_like(df_adv_st), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_tt, torch.ones_like(df_adv_tt), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_ts, torch.ones_like(df_adv_ts), reduction='sum')
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+
+ return AdvLoss_g
+
+ def calc_clsloss_g(self, df_cls_ss, df_cls_st, df_cls_tt, df_cls_ts, k_s, k_t):
+ device = df_cls_ss.device
+
+ df_cls_ss = df_cls_ss.permute(0,2,1).reshape(-1,self.n_spk)
+ df_cls_st = df_cls_st.permute(0,2,1).reshape(-1,self.n_spk)
+ df_cls_tt = df_cls_tt.permute(0,2,1).reshape(-1,self.n_spk)
+ df_cls_ts = df_cls_ts.permute(0,2,1).reshape(-1,self.n_spk)
+
+ cf_ss = k_s*torch.ones(len(df_cls_ss), device=device, dtype=torch.long)
+ cf_st = k_t*torch.ones(len(df_cls_st), device=device, dtype=torch.long)
+ cf_tt = k_t*torch.ones(len(df_cls_tt), device=device, dtype=torch.long)
+ cf_ts = k_s*torch.ones(len(df_cls_ts), device=device, dtype=torch.long)
+
+ ClsLoss_g = (
+ F.cross_entropy(df_cls_ss, cf_ss, reduction='sum') +
+ F.cross_entropy(df_cls_st, cf_st, reduction='sum') +
+ F.cross_entropy(df_cls_tt, cf_tt, reduction='sum') +
+ F.cross_entropy(df_cls_ts, cf_ts, reduction='sum')
+ ) / (df_cls_ss.numel() + df_cls_st.numel() + df_cls_tt.numel() + df_cls_ts.numel())
+
+ return ClsLoss_g
+
+ def calc_advloss_d(self, x_s, x_t, xf_ts, xf_st, dr_adv_s, dr_adv_t, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts):
+ device = x_s.device
+ B_s = len(x_s)
+ B_t = len(x_t)
+
+ dr_adv_s = dr_adv_s.permute(0,2,1).reshape(-1,1)
+ dr_adv_t = dr_adv_t.permute(0,2,1).reshape(-1,1)
+ df_adv_ss = df_adv_ss.permute(0,2,1).reshape(-1,1)
+ df_adv_st = df_adv_st.permute(0,2,1).reshape(-1,1)
+ df_adv_tt = df_adv_tt.permute(0,2,1).reshape(-1,1)
+ df_adv_ts = df_adv_ts.permute(0,2,1).reshape(-1,1)
+
+ if self.loss_type=='wgan':
+ # Wasserstein GAN with gradient penalty (WGAN-GP)
+ AdvLoss_d_r = (
+ torch.sum(-dr_adv_s) +
+ torch.sum(-dr_adv_t)
+ ) / (dr_adv_s.numel() + dr_adv_t.numel())
+ AdvLoss_d_f = (
+ torch.sum(df_adv_ss) +
+ torch.sum(df_adv_st) +
+ torch.sum(df_adv_tt) +
+ torch.sum(df_adv_ts)
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+ AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f
+
+ elif self.loss_type=='lsgan':
+ # Least squares GAN (LSGAN)
+
+ AdvLoss_d_r = 0.5 * (
+ torch.sum((dr_adv_s - torch.ones_like(dr_adv_s))**2) +
+ torch.sum((dr_adv_t - torch.ones_like(dr_adv_t))**2)
+ ) / (dr_adv_s.numel() + dr_adv_t.numel())
+ AdvLoss_d_f = 0.5 * (
+ torch.sum(df_adv_ss**2) +
+ torch.sum(df_adv_st**2) +
+ torch.sum(df_adv_tt**2) +
+ torch.sum(df_adv_ts**2)
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+ AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f
+
+ elif self.loss_type=='cgan':
+ # Regular GAN with sigmoid cross-entropy criterion (CGAN)
+ AdvLoss_d_r = (
+ F.binary_cross_entropy_with_logits(dr_adv_s, torch.ones_like(dr_adv_s), reduction='sum') +
+ F.binary_cross_entropy_with_logits(dr_adv_t, torch.ones_like(dr_adv_t), reduction='sum')
+ ) / (dr_adv_s.numel() + dr_adv_t.numel())
+ AdvLoss_d_f = (
+ F.binary_cross_entropy_with_logits(df_adv_ss, torch.zeros_like(df_adv_ss), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_st, torch.zeros_like(df_adv_st), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_tt, torch.zeros_like(df_adv_tt), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_ts, torch.zeros_like(df_adv_ts), reduction='sum')
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+ AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f
+
+ # Gradient penalty loss
+ alpha_t = torch.rand(B_t, 1, 1, requires_grad=True).to(device)
+ interpolates = alpha_t * x_t + ((1 - alpha_t) * xf_ts)
+ interpolates = interpolates.to(device)
+ disc_interpolates, _ = self.dis(interpolates)
+ disc_interpolates = torch.sum(disc_interpolates)
+ gradients = torch.autograd.grad(outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+ create_graph=True, retain_graph=True, only_inputs=True)[0]
+ gradnorm = torch.sqrt(torch.sum(gradients * gradients, (1, 2)))
+ loss_gp_t = ((gradnorm - 1)**2).mean()
+
+ alpha_s = torch.rand(B_s, 1, 1, requires_grad=True).to(device)
+ interpolates = alpha_s * x_s + ((1 - alpha_s) * xf_st)
+ interpolates = interpolates.to(device)
+ interpolates = torch.autograd.Variable(interpolates, requires_grad=True)
+ disc_interpolates, _ = self.dis(interpolates)
+ disc_interpolates = torch.sum(disc_interpolates)
+ gradients = torch.autograd.grad(outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+ create_graph=True, retain_graph=True, only_inputs=True)[0]
+ gradnorm = torch.sqrt(torch.sum(gradients * gradients, (1, 2)))
+ loss_gp_s = ((gradnorm - 1)**2).mean()
+
+ GradLoss_d = loss_gp_s + loss_gp_t
+
+ return AdvLoss_d, GradLoss_d
+
+ def calc_clsloss_d(self, dr_cls_s, dr_cls_t, k_s, k_t):
+ device = dr_cls_s.device
+
+ dr_cls_s = dr_cls_s.permute(0,2,1).reshape(-1,self.n_spk)
+ dr_cls_t = dr_cls_t.permute(0,2,1).reshape(-1,self.n_spk)
+
+ cr_s = k_s*torch.ones(len(dr_cls_s), device=device, dtype=torch.long)
+ cr_t = k_t*torch.ones(len(dr_cls_t), device=device, dtype=torch.long)
+
+ ClsLoss_d = (
+ F.cross_entropy(dr_cls_s, cr_s, reduction='sum') +
+ F.cross_entropy(dr_cls_t, cr_t, reduction='sum')
+ ) / (dr_cls_s.numel() + dr_cls_t.numel())
+
+ return ClsLoss_d
+
+ def calc_gen_loss(self, x_s, x_t, k_s, k_t):
+ # Generator outputs
+ xf_ss = self.gen(x_s, k_s, k_s)
+ xf_ts = self.gen(x_t, k_s, k_t)
+ xf_tt = self.gen(x_t, k_t, k_t)
+ xf_st = self.gen(x_s, k_t, k_s)
+
+ # Discriminator outputs
+ df_adv_ss, df_cls_ss = self.dis(xf_ss)
+ df_adv_st, df_cls_st = self.dis(xf_st)
+ df_adv_tt, df_cls_tt = self.dis(xf_tt)
+ df_adv_ts, df_cls_ts = self.dis(xf_ts)
+
+ # Adversarial loss
+ AdvLoss_g = self.calc_advloss_g(df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts)
+
+ # Classifier loss
+ ClsLoss_g = self.calc_clsloss_g(df_cls_ss, df_cls_st, df_cls_tt, df_cls_ts, k_s, k_t)
+
+ # Cycle-consistency loss
+ CycLoss = (
+ torch.sum(torch.abs(x_s - self.gen(xf_st, k_s, k_t))) +
+ torch.sum(torch.abs(x_t - self.gen(xf_ts, k_t, k_s)))
+ ) / (x_s.numel() + x_t.numel())
+
+ # Reconstruction loss
+ RecLoss = (
+ torch.sum(torch.abs(x_s - xf_ss)) +
+ torch.sum(torch.abs(x_t - xf_tt))
+ ) / (x_s.numel() + x_t.numel())
+
+ return AdvLoss_g, ClsLoss_g, CycLoss, RecLoss
+
+ def calc_dis_loss(self, x_s, x_t, k_s, k_t):
+ device = x_s.device
+
+ # Generator outputs
+ xf_ss = self.gen(x_s, k_s, k_s)
+ xf_ts = self.gen(x_t, k_s, k_t)
+ xf_tt = self.gen(x_t, k_t, k_t)
+ xf_st = self.gen(x_s, k_t, k_s)
+
+ # Discriminator outputs
+ dr_adv_s, dr_cls_s = self.dis(x_s)
+ dr_adv_t, dr_cls_t = self.dis(x_t)
+ df_adv_ss, _ = self.dis(xf_ss)
+ df_adv_st, _ = self.dis(xf_st)
+ df_adv_tt, _ = self.dis(xf_tt)
+ df_adv_ts, _ = self.dis(xf_ts)
+
+ # Adversarial loss
+ AdvLoss_d, GradLoss_d = self.calc_advloss_d(x_s, x_t, xf_ts, xf_st, dr_adv_s, dr_adv_t, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts)
+
+ # Classifier loss
+ ClsLoss_d = self.calc_clsloss_d(dr_cls_s, dr_cls_t, k_s, k_t)
+
+ return AdvLoss_d, GradLoss_d, ClsLoss_d
diff --git a/models/stargan/normalize_features.py b/models/stargan/normalize_features.py
new file mode 100644
index 00000000..0b4aa157
--- /dev/null
+++ b/models/stargan/normalize_features.py
@@ -0,0 +1,94 @@
+import argparse
+import joblib
+import logging
+import os
+
+import h5py
+import numpy as np
+from tqdm import tqdm
+from sklearn.preprocessing import StandardScaler
+import pickle
+
+def walk_files(root, extension):
+ for path, dirs, files in os.walk(root):
+ for file in files:
+ if file.endswith(extension):
+ yield os.path.join(path, file)
+
+def melspec_transform(melspec, scaler):
+ # melspec.shape: (n_freq, n_time)
+ # scaler.transform assumes the first axis to be the time axis
+ melspec = scaler.transform(melspec.T)
+ #import pdb;pdb.set_trace() # Breakpoint
+ melspec = melspec.T
+ return melspec
+
+def normalize_features(src_filepath, dst_filepath, melspec_transform):
+ try:
+ with h5py.File(src_filepath, "r") as f:
+ melspec = f["melspec"][()]
+ melspec = melspec_transform(melspec)
+
+ if not os.path.exists(os.path.dirname(dst_filepath)):
+ os.makedirs(os.path.dirname(dst_filepath), exist_ok=True)
+ with h5py.File(dst_filepath, "w") as f:
+ f.create_dataset("melspec", data=melspec)
+
+ #logging.info(f"{dst_filepath}...[{melspec.shape}].")
+ return melspec.shape
+
+    except Exception:
+ logging.info(f"{dst_filepath}...failed.")
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--src', type=str,
+ default='./dump/arctic/feat/train',
+ help='data folder that contains the raw features extracted from VoxCeleb2 Dataset')
+ parser.add_argument('--dst', type=str, default='./dump/arctic/norm_feat/train',
+ help='data folder where the normalized features are stored')
+ parser.add_argument('--stat', type=str, default='./dump/arctic/stat.pkl',
+ help='state file used for normalization')
+ parser.add_argument('--ext', type=str, default='.h5')
+ args = parser.parse_args()
+
+ src = args.src
+ dst = args.dst
+ ext = args.ext
+ stat_filepath = args.stat
+
+ fmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
+ datafmt = '%m/%d/%Y %I:%M:%S'
+ logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datafmt)
+
+ melspec_scaler = StandardScaler()
+ if os.path.exists(stat_filepath):
+ with open(stat_filepath, mode='rb') as f:
+ melspec_scaler = pickle.load(f)
+ print('Loaded mel-spectrogram statistics successfully.')
+ else:
+ print('Stat file not found.')
+
+ root = src
+ fargs_list = [
+ [
+ f,
+ f.replace(src, dst),
+ lambda x: melspec_transform(x, melspec_scaler),
+ ]
+ for f in walk_files(root, ext)
+ ]
+
+ #import pdb;pdb.set_trace() # Breakpoint
+ # debug
+ #normalize_features(*fargs_list[0])
+ # test
+ #results = joblib.Parallel(n_jobs=-1)(
+ # joblib.delayed(normalize_features)(*f) for f in tqdm(fargs_list)
+ #)
+ results = joblib.Parallel(n_jobs=16)(
+ joblib.delayed(normalize_features)(*f) for f in tqdm(fargs_list)
+ )
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/models/stargan/train_stargan_model.py b/models/stargan/train_stargan_model.py
new file mode 100644
index 00000000..d6e7dc05
--- /dev/null
+++ b/models/stargan/train_stargan_model.py
@@ -0,0 +1,467 @@
+import itertools
+import torch
+from torch import optim
+from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data import DataLoader
+import argparse
+import joblib
+import logging
+import os
+import warnings
+import json
+
+import librosa
+import soundfile as sf
+import h5py
+import numpy as np
+from tqdm import tqdm
+from sklearn.preprocessing import StandardScaler
+import pickle
+from dataset import MultiDomain_Dataset, collate_fn
+import net
+
+def walk_files(root, extension):
+ for path, dirs, files in os.walk(root):
+ for file in files:
+ if file.endswith(extension):
+ yield os.path.join(path, file)
+
+def logmelfilterbank(audio,
+ sampling_rate,
+ fft_size=1024,
+ hop_size=256,
+ win_length=None,
+ window="hann",
+ num_mels=80,
+ fmin=None,
+ fmax=None,
+ eps=1e-10,
+ ):
+ x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
+ win_length=win_length, window=window, pad_mode="reflect")
+ spc = np.abs(x_stft).T # (#frames, #bins)
+
+ # get mel basis
+ fmin = 0 if fmin is None else fmin
+ fmax = sampling_rate / 2 if fmax is None else fmax
+ # mel_basis = librosa.filters.mel(sampling_rate, fft_size, num_mels, fmin, fmax)
+ mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
+
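+    # Log10 mel energies, floored at eps to avoid log(0).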
+ return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
+
+def extract_melspec(src_filepath, dst_filepath, kwargs):
+ warnings.filterwarnings('ignore')
+
+ trim_silence = kwargs['trim_silence']
+ top_db = kwargs['top_db']
+ flen = kwargs['flen']
+ fshift = kwargs['fshift']
+ fmin = kwargs['fmin']
+ fmax = kwargs['fmax']
+ num_mels = kwargs['num_mels']
+ fs = kwargs['fs']
+
+ audio, fs_ = sf.read(src_filepath)
+ if trim_silence:
+ # print('trimming.')
+ audio, _ = librosa.effects.trim(audio, top_db=top_db, frame_length=2048, hop_length=512)
+ if fs != fs_:
+ # print('resampling.')
+ # audio = librosa.resample(audio, fs_, fs)
+ audio = librosa.resample(y=audio, orig_sr=fs_, target_sr=fs)
+
+ melspec_raw = logmelfilterbank(audio,fs, fft_size=flen,hop_size=fshift,
+ fmin=fmin, fmax=fmax, num_mels=num_mels)
+ melspec_raw = melspec_raw.astype(np.float32)
+ melspec_raw = melspec_raw.T # n_mels x n_frame
+ if not os.path.exists(os.path.dirname(dst_filepath)):
+ os.makedirs(os.path.dirname(dst_filepath), exist_ok=True)
+ with h5py.File(dst_filepath, "w") as f:
+ f.create_dataset("melspec", data=melspec_raw)
+ logging.info(f"{dst_filepath}...[{melspec_raw.shape}].")
+
+
+def makedirs_if_not_exists(dir):
+ if not os.path.exists(dir):
+ os.makedirs(dir)
+
+
+def comb(N, r):
+    # Return all r-element index combinations of range(N); the training loop uses r=2 for speaker pairs.
+    return list(itertools.combinations(range(N), r))
+
+
+src = './dataset/vctk/data'
+dst = './models/stargan/dump/arctic/feat/train'
+ext = '.wav'
+
+fmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
+datafmt = '%m/%d/%Y %I:%M:%S'
+logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datafmt)
+
+data_config = {
+ 'num_mels': 80,
+ 'fs': 16000,
+ 'flen': 1024,
+ 'fshift': 128,
+ 'fmin': 80,
+ 'fmax': 7600,
+ 'trim_silence': True,
+ 'top_db': 30
+}
+configpath = './models/stargan/dump/arctic/data_config.json'
+if not os.path.exists(os.path.dirname(configpath)):
+ os.makedirs(os.path.dirname(configpath))
+with open(configpath, 'w') as outfile:
+ json.dump(data_config, outfile, indent=4)
+
+fargs_list = [
+ [
+ f,
+ f.replace(src, dst).replace(ext, ".h5"),
+ data_config,
+ ]
+ for f in walk_files(src, ext)
+]
+print(fargs_list)
+results = joblib.Parallel(n_jobs=1)(
+ joblib.delayed(extract_melspec)(*f) for f in tqdm(fargs_list)
+)
+
+
+def walk_files(root, extension):
+ for path, dirs, files in os.walk(root):
+ for file in files:
+ if file.endswith(extension):
+ yield os.path.join(path, file)
+
+
+def melspec_transform(melspec, scaler):
+ melspec = scaler.transform(melspec.T)
+ melspec = melspec.T
+ return melspec
+
+def normalize_features(src_filepath, dst_filepath, melspec_transform):
+ try:
+ with h5py.File(src_filepath, "r") as f:
+ melspec = f["melspec"][()]
+ melspec = melspec_transform(melspec)
+
+ if not os.path.exists(os.path.dirname(dst_filepath)):
+ os.makedirs(os.path.dirname(dst_filepath), exist_ok=True)
+ with h5py.File(dst_filepath, "w") as f:
+ f.create_dataset("melspec", data=melspec)
+
+ # logging.info(f"{dst_filepath}...[{melspec.shape}].")
+ return melspec.shape
+
+    except Exception:
+ logging.info(f"{dst_filepath}...failed.")
+
+# parser.add_argument('--src', type=str,
+# default='./models/stargan/dump/arctic/feat/train',
+# help='data folder that contains the raw features extracted from VoxCeleb2 Dataset')
+# parser.add_argument('--dst', type=str, default='./models/stargan/dump/arctic/norm_feat/train',
+# help='data folder where the normalized features are stored')
+# parser.add_argument('--stat', type=str, default='./models/stargan/dump/arctic/stat.pkl',
+# help='state file used for normalization')
+# parser.add_argument('--ext', type=str, default='.h5')
+
+
+src = './models/stargan/dump/arctic/feat/train'
+dst = './models/stargan/dump/arctic/norm_feat/train'
+ext = '.h5'
+stat_filepath = './models/stargan/dump/arctic/stat.pkl'
+
+fmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
+datafmt = '%m/%d/%Y %I:%M:%S'
+logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datafmt)
+
+melspec_scaler = StandardScaler()
+if os.path.exists(stat_filepath):
+ with open(stat_filepath, mode='rb') as f:
+ melspec_scaler = pickle.load(f)
+ print('Loaded mel-spectrogram statistics successfully.')
+else:
+ print('Stat file not found.')
+
+root = src
+fargs_list = [
+ [
+ f,
+ f.replace(src, dst),
+ lambda x: melspec_transform(x, melspec_scaler),
+ ]
+ for f in walk_files(root, ext)
+]
+print(fargs_list)
+results = joblib.Parallel(n_jobs=1)(
+ joblib.delayed(normalize_features)(*f) for f in tqdm(fargs_list)
+)
+
+
+def walk_files(root, extension):
+ for path, dirs, files in os.walk(root):
+ for file in files:
+ if file.endswith(extension):
+ yield os.path.join(path, file)
+
+def read_melspec(filepath):
+ with h5py.File(filepath, "r") as f:
+ melspec = f["melspec"][()] # n_mels x n_frame
+ # import pdb;pdb.set_trace() # Breakpoint
+ return melspec
+
+
+def compute_statistics(src, stat_filepath):
+ melspec_scaler = StandardScaler()
+
+ filepath_list = list(walk_files(src, '.h5'))
+ for filepath in tqdm(filepath_list):
+ melspec = read_melspec(filepath)
+ # import pdb;pdb.set_trace() # Breakpoint
+ melspec_scaler.partial_fit(melspec.T)
+
+ with open(stat_filepath, mode='wb') as f:
+ pickle.dump(melspec_scaler, f)
+
+
+fmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
+datafmt = '%m/%d/%Y %I:%M:%S'
+logging.basicConfig(level=logging.INFO, format=fmt, datefmt=datafmt)
+
+src = './models/stargan/dump/arctic/feat/train'
+stat_filepath = './models/stargan/dump/arctic/stat.pkl'
+if not os.path.exists(os.path.dirname(stat_filepath)):
+ os.makedirs(os.path.dirname(stat_filepath))
+
+compute_statistics(src, stat_filepath)
+
+gpu = -1
+data_rootdir = './models/stargan/dump/arctic/norm_feat/train'
+epochs = 2000
+snapshot = 200
+batch_size = 12
+num_mels = 80
+arch_type = 'conv'
+loss_type = 'wgan'
+zdim = 16
+hdim = 64
+mdim = 32
+sdim = 16
+lrate_g = 0.0005
+lrate_d = 5e-6
+gradient_clip = 1.0
+w_adv = 1.0
+w_grad = 1.0
+w_cls = 1.0
+w_cyc = 1.0
+w_rec = 1.0
+normtype = 'IN'
+src_conditioning = 0
+resume = 0
+model_rootdir = './model/arctic/'
+log_dir = './logs/arctic/'
+experiment_name = 'experiment1'
+
+
+
+
+def train_stargan():
+ # Set up GPU
+ if torch.cuda.is_available() and gpu >= 0:
+ device = torch.device('cuda:%d' % gpu)
+ else:
+ device = torch.device('cpu')
+ if device.type == 'cuda':
+ torch.cuda.set_device(device)
+
+
+ spk_list = sorted(os.listdir(data_rootdir))
+ n_spk = len(spk_list)
+ melspec_dirs = [os.path.join(data_rootdir, spk) for spk in spk_list]
+
+ config = {
+ 'num_mels': num_mels,
+ 'arch_type': arch_type,
+ 'loss_type': loss_type,
+ 'zdim': zdim,
+ 'hdim': hdim,
+ 'mdim': mdim,
+ 'sdim': sdim,
+ 'w_adv': 1.0,
+ 'w_grad': 1.0,
+ 'w_cls': 1.0,
+ 'w_cyc': 1.0,
+ 'w_rec': 1.0,
+ 'lrate_g': lrate_g,
+ 'lrate_d': lrate_d,
+ 'gradient_clip': 1.0,
+ 'normtype': normtype,
+ 'epochs': epochs,
+ 'BatchSize': batch_size,
+ 'n_spk': n_spk,
+ 'spk_list': spk_list,
+ 'src_conditioning': src_conditioning
+ }
+
+ model_dir = os.path.join(model_rootdir, experiment_name)
+ makedirs_if_not_exists(model_dir)
+ log_path = os.path.join(log_dir, experiment_name, 'train_{}.log'.format(experiment_name))
+
+ # Save configuration as a json file
+ config_path = os.path.join(model_dir, 'model_config.json')
+ with open(config_path, 'w') as outfile:
+ json.dump(config, outfile, indent=4)
+
+    if arch_type == 'conv':
+        gen = net.Generator1(num_mels, n_spk, zdim, hdim, sdim, normtype, src_conditioning)
+    elif arch_type == 'rnn':
+        gen = net.Generator2(num_mels, n_spk, zdim, hdim, sdim, src_conditioning=src_conditioning)
+    dis = net.Discriminator1(num_mels, n_spk, mdim, normtype)
+ models = {
+ 'gen': gen,
+ 'dis': dis
+ }
+ models['stargan'] = net.StarGAN(models['gen'], models['dis'], n_spk, loss_type)
+
+ optimizers = {
+ 'gen': optim.Adam(models['gen'].parameters(), lr=lrate_g, betas=(0.9, 0.999)),
+ 'dis': optim.Adam(models['dis'].parameters(), lr=lrate_d, betas=(0.5, 0.999))
+ }
+
+ for tag in ['gen', 'dis']:
+ models[tag].to(device).train(mode=True)
+
+ train_dataset = MultiDomain_Dataset(*melspec_dirs)
+ train_loader = DataLoader(train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=0,
+ # num_workers=os.cpu_count(),
+ drop_last=True,
+ collate_fn=collate_fn)
+
+ fmt = '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s'
+ datafmt = '%m/%d/%Y %I:%M:%S'
+ if not os.path.exists(os.path.dirname(log_path)):
+ os.makedirs(os.path.dirname(log_path))
+ logging.basicConfig(filename=log_path, filemode='a', level=logging.INFO, format=fmt, datefmt=datafmt)
+ writer = SummaryWriter(os.path.dirname(log_path))
+
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir)
+
+ for tag in ['gen', 'dis']:
+ checkpointpath = os.path.join(model_dir, '{}.{}.pt'.format(resume, tag))
+ if os.path.exists(checkpointpath):
+ checkpoint = torch.load(checkpointpath, map_location=device)
+ models[tag].load_state_dict(checkpoint['model_state_dict'])
+ optimizers[tag].load_state_dict(checkpoint['optimizer_state_dict'])
+ print('{} loaded successfully.'.format(checkpointpath))
+
+ w_adv = config['w_adv']
+ w_grad = config['w_grad']
+ w_cls = config['w_cls']
+ w_cyc = config['w_cyc']
+ w_rec = config['w_rec']
+ gradient_clip = config['gradient_clip']
+
+ print("===================================Training Started===================================")
+ n_iter = 0
+ for epoch in range(resume + 1, epochs + 1):
+ b = 0
+ for X_list in train_loader:
+ n_spk = len(X_list)
+ xin = []
+ for s in range(n_spk):
+ xin.append(torch.tensor(X_list[s]).to(device, dtype=torch.float))
+
+ # List of speaker pairs
+ spk_pair_list = comb(n_spk, 2)
+ n_spk_pair = len(spk_pair_list)
+
+ gen_loss_mean = 0
+ dis_loss_mean = 0
+ advloss_d_mean = 0
+ gradloss_d_mean = 0
+ advloss_g_mean = 0
+ clsloss_d_mean = 0
+ clsloss_g_mean = 0
+ cycloss_mean = 0
+ recloss_mean = 0
+ # Iterate through all speaker pairs
+ for m in range(n_spk_pair):
+ s0 = spk_pair_list[m][0]
+ s1 = spk_pair_list[m][1]
+
+ AdvLoss_g, ClsLoss_g, CycLoss, RecLoss = models['stargan'].calc_gen_loss(xin[s0], xin[s1], s0, s1)
+ gen_loss = (w_adv * AdvLoss_g + w_cls * ClsLoss_g + w_cyc * CycLoss + w_rec * RecLoss)
+
+ models['gen'].zero_grad()
+ gen_loss.backward()
+ torch.nn.utils.clip_grad_norm_(models['gen'].parameters(), gradient_clip)
+ optimizers['gen'].step()
+
+ AdvLoss_d, GradLoss_d, ClsLoss_d = models['stargan'].calc_dis_loss(xin[s0], xin[s1], s0, s1)
+ dis_loss = w_adv * AdvLoss_d + w_grad * GradLoss_d + w_cls * ClsLoss_d
+
+ models['dis'].zero_grad()
+ dis_loss.backward()
+ torch.nn.utils.clip_grad_norm_(models['dis'].parameters(), gradient_clip)
+ optimizers['dis'].step()
+
+ gen_loss_mean += gen_loss.item()
+ dis_loss_mean += dis_loss.item()
+ advloss_d_mean += AdvLoss_d.item()
+ gradloss_d_mean += GradLoss_d.item()
+ advloss_g_mean += AdvLoss_g.item()
+ clsloss_d_mean += ClsLoss_d.item()
+ clsloss_g_mean += ClsLoss_g.item()
+ cycloss_mean += CycLoss.item()
+ recloss_mean += RecLoss.item()
+
+ gen_loss_mean /= n_spk_pair
+ dis_loss_mean /= n_spk_pair
+ advloss_d_mean /= n_spk_pair
+ gradloss_d_mean /= n_spk_pair
+ advloss_g_mean /= n_spk_pair
+ clsloss_d_mean /= n_spk_pair
+ clsloss_g_mean /= n_spk_pair
+ cycloss_mean /= n_spk_pair
+ recloss_mean /= n_spk_pair
+
+ logging.info(
+ 'epoch {}, mini-batch {}: AdvLoss_d={:.4f}, AdvLoss_g={:.4f}, GradLoss_d={:.4f}, ClsLoss_d={:.4f}, ClsLoss_g={:.4f}'
+ .format(epoch, b + 1, w_adv * advloss_d_mean, w_adv * advloss_g_mean, w_grad * gradloss_d_mean,
+ w_cls * clsloss_d_mean, w_cls * clsloss_g_mean))
+ logging.info(
+ 'epoch {}, mini-batch {}: CycLoss={:.4f}, RecLoss={:.4f}'.format(epoch, b + 1, w_cyc * cycloss_mean,
+ w_rec * recloss_mean))
+ writer.add_scalars('Loss/Total_Loss', {'adv_loss_d': w_adv * advloss_d_mean,
+ 'adv_loss_g': w_adv * advloss_g_mean,
+ 'grad_loss_d': w_grad * gradloss_d_mean,
+ 'cls_loss_d': w_cls * clsloss_d_mean,
+ 'cls_loss_g': w_cls * clsloss_g_mean,
+ 'cyc_loss': w_cyc * cycloss_mean,
+ 'rec_loss': w_rec * recloss_mean}, n_iter)
+ n_iter += 1
+ b += 1
+
+ if epoch % snapshot == 0:
+ for tag in ['gen', 'dis']:
+ print('save {} at {} epoch'.format(tag, epoch))
+ torch.save({'epoch': epoch,
+ 'model_state_dict': models[tag].state_dict(),
+ 'optimizer_state_dict': optimizers[tag].state_dict()},
+ os.path.join(model_dir, '{}.{}.pt'.format(epoch, tag)))
+
+ print("===================================Training Finished===================================")
+
+
+train_stargan()
+# Train(models, epochs, train_dataset, train_loader, optimizers, device, model_dir, log_path, model_config, snapshot,
+# resume)
\ No newline at end of file
diff --git a/models/stargan/useless.txt b/models/stargan/useless.txt
new file mode 100644
index 00000000..e69de29b
diff --git a/module.py b/module.py
new file mode 100644
index 00000000..f88b3bee
--- /dev/null
+++ b/module.py
@@ -0,0 +1,136 @@
+# Copyright 2021 Hirokazu Kameoka
+
+import torch
+import torch.nn as nn
+
+def calc_padding(kernel_size, dilation, causal, stride=1):
+ if causal:
+ padding = (kernel_size-1)*dilation+1-stride
+ else:
+ padding = ((kernel_size-1)*dilation+1-stride)//2
+ return padding
+
+class LinearWN(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True):
+ super(LinearWN, self).__init__()
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+ #nn.init.xavier_normal_(self.linear_layer.weight,gain=0.1)
+ self.linear_layer = nn.utils.weight_norm(self.linear_layer)
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+class ConvGLU1D(nn.Module):
+ def __init__(self, in_ch, out_ch, ks, sd, normtype='IN'):
+ super(ConvGLU1D, self).__init__()
+ self.conv1 = nn.Conv1d(
+ in_ch, out_ch*2, ks, stride=sd, padding=(ks-sd)//2)
+ nn.init.xavier_normal_(self.conv1.weight,gain=0.1)
+ if normtype=='BN':
+ self.norm1 = nn.BatchNorm1d(out_ch*2)
+ elif normtype=='IN':
+ self.norm1 = nn.InstanceNorm1d(out_ch*2)
+ elif normtype=='LN':
+ self.norm1 = nn.LayerNorm(out_ch*2)
+
+ self.conv1 = nn.utils.weight_norm(self.conv1)
+ self.normtype = normtype
+
+ def __call__(self, x):
+ h = self.conv1(x)
+ if self.normtype=='BN' or self.normtype=='IN':
+ h = self.norm1(h)
+ elif self.normtype=='LN':
+ B, D, N = h.shape
+ h = h.permute(0,2,1).reshape(-1,D)
+ h = self.norm1(h)
+ h = h.reshape(B,N,D).permute(0,2,1)
+ h_l, h_g = torch.split(h, h.shape[1]//2, dim=1)
+ h = h_l * torch.sigmoid(h_g)
+
+ return h
+
+class DeconvGLU1D(nn.Module):
+ def __init__(self, in_ch, out_ch, ks, sd, normtype='IN'):
+ super(DeconvGLU1D, self).__init__()
+ self.conv1 = nn.ConvTranspose1d(
+ in_ch, out_ch*2, ks, stride=sd, padding=(ks-sd)//2)
+ nn.init.xavier_normal_(self.conv1.weight,gain=0.1)
+ if normtype=='BN':
+ self.norm1 = nn.BatchNorm1d(out_ch*2)
+ elif normtype=='IN':
+ self.norm1 = nn.InstanceNorm1d(out_ch*2)
+ elif normtype=='LN':
+ self.norm1 = nn.LayerNorm(out_ch*2)
+
+ self.conv1 = nn.utils.weight_norm(self.conv1)
+ self.normtype = normtype
+
+ def __call__(self, x):
+ h = self.conv1(x)
+ if self.normtype=='BN' or self.normtype=='IN':
+ h = self.norm1(h)
+ elif self.normtype=='LN':
+ B, D, N = h.shape
+ h = h.permute(0,2,1).reshape(-1,D)
+ h = self.norm1(h)
+ h = h.reshape(B,N,D).permute(0,2,1)
+ h_l, h_g = torch.split(h, h.shape[1]//2, dim=1)
+ h = h_l * torch.sigmoid(h_g)
+
+ return h
+
+class PixelShuffleGLU1D(nn.Module):
+ def __init__(self, in_ch, out_ch, ks, sd, normtype='IN'):
+ super(PixelShuffleGLU1D, self).__init__()
+ self.conv1 = nn.Conv1d(
+ in_ch, out_ch*2*sd, ks, stride=1, padding=(ks-1)//2)
+ self.r = sd
+ if normtype=='BN':
+ self.norm1 = nn.BatchNorm1d(out_ch*2)
+ elif normtype=='IN':
+ self.norm1 = nn.InstanceNorm1d(out_ch*2)
+ self.conv1 = nn.utils.weight_norm(self.conv1)
+ self.normtype = normtype
+
+ def __call__(self, x):
+ h = self.conv1(x)
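+        # 1D pixel shuffle: the conv expands channels by a factor of r = sd, which is then
+        # folded into a factor-r upsampling along the time axis.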
+ N, pre_ch, pre_len = h.shape
+ r = self.r
+ post_ch = pre_ch//r
+ post_len = pre_len * r
+ h = torch.reshape(h, (N, r, post_ch, pre_len))
+ h = h.permute(0,2,3,1)
+
+ h = torch.reshape(h, (N, post_ch, post_len))
+ if self.normtype=='BN' or self.normtype=='IN':
+ h = self.norm1(h)
+ h_l, h_g = torch.split(h, h.shape[1]//2, dim=1)
+ h = h_l * torch.sigmoid(h_g)
+
+ return h
+
+def concat_dim1(x,y):
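+    # Broadcast the per-utterance conditioning vector y along the time (and, for 4D
+    # inputs, frequency) axes of x, then concatenate it along the channel dimension.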
+ assert x.shape[0] == y.shape[0]
+ if torch.Tensor.dim(x) == 3:
+ y0 = torch.unsqueeze(y,2)
+ N, n_ch, n_t = x.shape
+ yy = y0.repeat(1,1,n_t)
+ h = torch.cat((x,yy), dim=1)
+ elif torch.Tensor.dim(x) == 4:
+ y0 = torch.unsqueeze(torch.unsqueeze(y,2),3)
+ N, n_ch, n_q, n_t = x.shape
+ yy = y0.repeat(1,1,n_q,n_t)
+ h = torch.cat((x,yy), dim=1)
+ return h
+
+def concat_dim2(x,y):
+ assert x.shape[0] == y.shape[0]
+ if torch.Tensor.dim(x) == 3:
+ y0 = torch.unsqueeze(y,1)
+ N, n_t, n_ch = x.shape
+ yy = y0.repeat(1,n_t,1)
+ h = torch.cat((x,yy), dim=2)
+ elif torch.Tensor.dim(x) == 2:
+ h = torch.cat((x,y), dim=1)
+ return h
\ No newline at end of file
diff --git a/net.py b/net.py
new file mode 100644
index 00000000..60531dea
--- /dev/null
+++ b/net.py
@@ -0,0 +1,396 @@
+import numpy as np
+import six
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import module as md
+
+class Generator1(nn.Module):
+ # 1D convolutional architecture
+ def __init__(self, in_ch, n_spk, z_ch, mid_ch, s_ch, normtype='IN', src_conditioning=False):
+ super(Generator1, self).__init__()
+ add_ch = 0 if src_conditioning==False else s_ch
+ self.le1 = md.ConvGLU1D(in_ch+add_ch, mid_ch, 9, 1, normtype)
+ self.le2 = md.ConvGLU1D(mid_ch+add_ch, mid_ch, 8, 2, normtype)
+ self.le3 = md.ConvGLU1D(mid_ch+add_ch, mid_ch, 8, 2, normtype)
+ self.le4 = md.ConvGLU1D(mid_ch+add_ch, mid_ch, 5, 1, normtype)
+ self.le5 = md.ConvGLU1D(mid_ch+add_ch, z_ch, 5, 1, normtype)
+ self.le6 = md.DeconvGLU1D(z_ch+s_ch, mid_ch, 5, 1, normtype)
+ self.le7 = md.DeconvGLU1D(mid_ch+s_ch, mid_ch, 5, 1, normtype)
+ self.le8 = md.DeconvGLU1D(mid_ch+s_ch, mid_ch, 8, 2, normtype)
+ self.le9 = md.DeconvGLU1D(mid_ch+s_ch, mid_ch, 8, 2, normtype)
+ self.le10 = nn.Conv1d(mid_ch+s_ch, in_ch, 9, stride=1, padding=(9-1)//2)
+ #nn.init.xavier_normal_(self.le10.weight,gain=0.1)
+
+ if src_conditioning:
+ self.eb0 = nn.Embedding(n_spk, s_ch)
+ self.eb1 = nn.Embedding(n_spk, s_ch)
+ self.src_conditioning = src_conditioning
+
+ def __call__(self, xin, k_t, k_s=None):
+ device = xin.device
+ B, n_mels, n_frame_ = xin.shape
+
+ kk_t = k_t*torch.ones(B).to(device, dtype=torch.int64)
+ trgspk_emb = self.eb1(kk_t)
+ if self.src_conditioning:
+ kk_s = k_s*torch.ones(B).to(device, dtype=torch.int64)
+ srcspk_emb = self.eb0(kk_s)
+
+ out = xin
+
+ if self.src_conditioning: out = md.concat_dim1(out,srcspk_emb)
+ out = self.le1(out)
+ if self.src_conditioning:
+ out = md.concat_dim1(out,srcspk_emb)
+ out = self.le2(out)
+ if self.src_conditioning:
+ out = md.concat_dim1(out,srcspk_emb)
+ out = self.le3(out)
+ if self.src_conditioning:
+ out = md.concat_dim1(out,srcspk_emb)
+ out = self.le4(out)
+ if self.src_conditioning:
+ out = md.concat_dim1(out,srcspk_emb)
+ out = self.le5(out)
+ out = md.concat_dim1(out,trgspk_emb)
+ out = self.le6(out)
+ out = md.concat_dim1(out,trgspk_emb)
+ out = self.le7(out)
+ out = md.concat_dim1(out,trgspk_emb)
+ out = self.le8(out)
+ out = md.concat_dim1(out,trgspk_emb)
+ out = self.le9(out)
+ out = md.concat_dim1(out,trgspk_emb)
+ out = self.le10(out)
+
+ return out
+
+class Generator2(nn.Module):
+ # Bidirectional LSTM
+ def __init__(self, in_ch, n_spk, z_ch, mid_ch, s_ch, num_layers=2, negative_slope=0.1, src_conditioning=False):
+ super(Generator2, self).__init__()
+ add_ch = 0 if src_conditioning==False else s_ch
+
+ self.linear0 = md.LinearWN(in_ch+add_ch, mid_ch)
+ self.lrelu0 = nn.LeakyReLU(negative_slope)
+ self.rnn0 = nn.LSTM(
+ mid_ch+add_ch,
+ mid_ch//2,
+ num_layers,
+ dropout=0,
+ bidirectional=True,
+ batch_first = True
+ )
+ self.linear1 = md.LinearWN(mid_ch+add_ch, z_ch)
+ self.linear2 = md.LinearWN(z_ch+s_ch, mid_ch)
+ self.lrelu1 = nn.LeakyReLU(negative_slope)
+ self.rnn1 = nn.LSTM(
+ mid_ch+s_ch,
+ mid_ch//2,
+ num_layers,
+ dropout=0,
+ bidirectional=True,
+ batch_first = True
+ )
+ self.linear3 = md.LinearWN(mid_ch+s_ch, in_ch)
+
+ if src_conditioning:
+ self.eb0 = nn.Embedding(n_spk, s_ch)
+ self.eb1 = nn.Embedding(n_spk, s_ch)
+ self.src_conditioning = src_conditioning
+
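+    # Each bidirectional LSTM uses a hidden size of mid_ch//2 per direction, so the
+    # concatenated forward/backward outputs stay at mid_ch features; speaker
+    # embeddings are appended along the feature axis before every layer.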
+ def __call__(self, xin, k_t, k_s=None):
+ device = xin.device
+ B, num_mels, num_frame = xin.shape
+ kk_t = k_t*torch.ones(B).to(device, dtype=torch.int64)
+ trgspk_emb = self.eb1(kk_t)
+ if self.src_conditioning:
+ kk_s = k_s*torch.ones(B).to(device, dtype=torch.int64)
+ srcspk_emb = self.eb0(kk_s)
+ out = xin
+
+ out = out.permute(0,2,1) # (B, num_frame, num_mels)
+ if self.src_conditioning: out = md.concat_dim2(out,srcspk_emb) # (B, num_frame, num_mels+add_ch)
+ out = self.lrelu0(self.linear0(out))
+ if self.src_conditioning: out = md.concat_dim2(out,srcspk_emb) # (B, num_frame, mid_ch+add_ch)
+ self.rnn0.flatten_parameters()
+ out, _ = self.rnn0(out) # (B, num_frame, mid_ch)
+ if self.src_conditioning: out = md.concat_dim2(out,srcspk_emb) # (B, num_frame, mid_ch+add_ch)
+ out = self.linear1(out) # (B, num_frame, z_ch)
+ out = md.concat_dim2(out,trgspk_emb) # (B, num_frame, z_ch+s_ch)
+ out = self.lrelu1(self.linear2(out)) # (B, num_frame, mid_ch)
+ out = md.concat_dim2(out,trgspk_emb) # (B, num_frame, mid_ch+s_ch)
+ self.rnn1.flatten_parameters()
+ out, _ = self.rnn1(out) # (B, num_frame, mid_ch)
+ out = md.concat_dim2(out,trgspk_emb) # (B, num_frame, mid_ch+s_ch)
+ out = self.linear3(out) # (B, num_frame, in_ch)
+ out = out.permute(0,2,1) # (B, in_ch, num_frame)
+
+ return out
+
+class Discriminator1(nn.Module):
+ # 1D convolutional architecture
+ def __init__(self, in_ch, clsnum, mid_ch, normtype='IN', dor=0.1):
+ super(Discriminator1, self).__init__()
+ self.le1 = md.ConvGLU1D(in_ch, mid_ch, 9, 1, normtype)
+ self.le2 = md.ConvGLU1D(mid_ch, mid_ch, 8, 2, normtype)
+ self.le3 = md.ConvGLU1D(mid_ch, mid_ch, 8, 2, normtype)
+ self.le4 = md.ConvGLU1D(mid_ch, mid_ch, 5, 1, normtype)
+ self.le_adv = nn.Conv1d(mid_ch, 1, 5, stride=1, padding=(5-1)//2, bias=False)
+ self.le_cls = nn.Conv1d(mid_ch, clsnum, 5, stride=1, padding=(5-1)//2, bias=False)
+ nn.init.xavier_normal_(self.le_adv.weight,gain=0.1)
+ nn.init.xavier_normal_(self.le_cls.weight,gain=0.1)
+ self.do1 = nn.Dropout(p=dor)
+ self.do2 = nn.Dropout(p=dor)
+ self.do3 = nn.Dropout(p=dor)
+ self.do4 = nn.Dropout(p=dor)
+
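+    # PatchGAN-style heads: le_adv yields one adversarial logit and le_cls yields
+    # clsnum speaker logits per time step, rather than a single scalar per utterance.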
+ def __call__(self, xin):
+ device = xin.device
+ B, n_mels, n_frame_ = xin.shape
+
+ out = xin
+
+ out = self.do1(self.le1(out))
+ out = self.do2(self.le2(out))
+ out = self.do3(self.le3(out))
+ out = self.do4(self.le4(out))
+
+ out_adv = self.le_adv(out)
+ out_cls = self.le_cls(out)
+
+ return out_adv, out_cls
+
+class StarGAN(nn.Module):
+ def __init__(self, gen, dis, n_spk, loss_type='wgan'):
+ super(StarGAN, self).__init__()
+ self.gen = gen
+ self.dis = dis
+ self.n_spk = n_spk
+ self.loss_type = loss_type
+
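+    # Pads the input to a multiple of 4 frames (the 1D convolutional generator
+    # downsamples twice with stride 2), runs conversion, then trims the output
+    # back to the original length.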
+ def forward(self, x, k_t, k_s=None):
+ device = x.device
+ n_frame_ = x.shape[2]
+ n_frame = math.ceil(n_frame_/4)*4
+ if n_frame > n_frame_:
+ x = nn.ReplicationPad1d((0, n_frame-n_frame_))(x)
+ return self.gen(x, k_t, k_s)[:,:,0:n_frame_]
+
+ def calc_advloss_g(self, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts):
+ df_adv_ss = df_adv_ss.permute(0,2,1).reshape(-1,1)
+ df_adv_st = df_adv_st.permute(0,2,1).reshape(-1,1)
+ df_adv_tt = df_adv_tt.permute(0,2,1).reshape(-1,1)
+ df_adv_ts = df_adv_ts.permute(0,2,1).reshape(-1,1)
+
+ if self.loss_type=='wgan':
+ # Wasserstein GAN with gradient penalty (WGAN-GP)
+ AdvLoss_g = (
+ torch.sum(-df_adv_ss) +
+ torch.sum(-df_adv_st) +
+ torch.sum(-df_adv_tt) +
+ torch.sum(-df_adv_ts)
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+
+ elif self.loss_type=='lsgan':
+ # Least squares GAN (LSGAN)
+ AdvLoss_g = 0.5 * (
+ torch.sum((df_adv_ss - torch.ones_like(df_adv_ss))**2) +
+ torch.sum((df_adv_st - torch.ones_like(df_adv_st))**2) +
+ torch.sum((df_adv_tt - torch.ones_like(df_adv_tt))**2) +
+ torch.sum((df_adv_ts - torch.ones_like(df_adv_ts))**2)
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+
+ elif self.loss_type=='cgan':
+ # Regular GAN with the sigmoid cross-entropy criterion (CGAN)
+ AdvLoss_g = (
+ F.binary_cross_entropy_with_logits(df_adv_ss, torch.ones_like(df_adv_ss), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_st, torch.ones_like(df_adv_st), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_tt, torch.ones_like(df_adv_tt), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_ts, torch.ones_like(df_adv_ts), reduction='sum')
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+
+ return AdvLoss_g
+
+ def calc_clsloss_g(self, df_cls_ss, df_cls_st, df_cls_tt, df_cls_ts, k_s, k_t):
+ device = df_cls_ss.device
+
+ df_cls_ss = df_cls_ss.permute(0,2,1).reshape(-1,self.n_spk)
+ df_cls_st = df_cls_st.permute(0,2,1).reshape(-1,self.n_spk)
+ df_cls_tt = df_cls_tt.permute(0,2,1).reshape(-1,self.n_spk)
+ df_cls_ts = df_cls_ts.permute(0,2,1).reshape(-1,self.n_spk)
+
+ cf_ss = k_s*torch.ones(len(df_cls_ss), device=device, dtype=torch.long)
+ cf_st = k_t*torch.ones(len(df_cls_st), device=device, dtype=torch.long)
+ cf_tt = k_t*torch.ones(len(df_cls_tt), device=device, dtype=torch.long)
+ cf_ts = k_s*torch.ones(len(df_cls_ts), device=device, dtype=torch.long)
+
+ ClsLoss_g = (
+ F.cross_entropy(df_cls_ss, cf_ss, reduction='sum') +
+ F.cross_entropy(df_cls_st, cf_st, reduction='sum') +
+ F.cross_entropy(df_cls_tt, cf_tt, reduction='sum') +
+ F.cross_entropy(df_cls_ts, cf_ts, reduction='sum')
+ ) / (df_cls_ss.numel() + df_cls_st.numel() + df_cls_tt.numel() + df_cls_ts.numel())
+
+ return ClsLoss_g
+
+ def calc_advloss_d(self, x_s, x_t, xf_ts, xf_st, dr_adv_s, dr_adv_t, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts):
+ device = x_s.device
+ B_s = len(x_s)
+ B_t = len(x_t)
+
+ dr_adv_s = dr_adv_s.permute(0,2,1).reshape(-1,1)
+ dr_adv_t = dr_adv_t.permute(0,2,1).reshape(-1,1)
+ df_adv_ss = df_adv_ss.permute(0,2,1).reshape(-1,1)
+ df_adv_st = df_adv_st.permute(0,2,1).reshape(-1,1)
+ df_adv_tt = df_adv_tt.permute(0,2,1).reshape(-1,1)
+ df_adv_ts = df_adv_ts.permute(0,2,1).reshape(-1,1)
+
+ if self.loss_type=='wgan':
+ # Wasserstein GAN with gradient penalty (WGAN-GP)
+ AdvLoss_d_r = (
+ torch.sum(-dr_adv_s) +
+ torch.sum(-dr_adv_t)
+ ) / (dr_adv_s.numel() + dr_adv_t.numel())
+ AdvLoss_d_f = (
+ torch.sum(df_adv_ss) +
+ torch.sum(df_adv_st) +
+ torch.sum(df_adv_tt) +
+ torch.sum(df_adv_ts)
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+ AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f
+
+ elif self.loss_type=='lsgan':
+ # Least squares GAN (LSGAN)
+
+ AdvLoss_d_r = 0.5 * (
+ torch.sum((dr_adv_s - torch.ones_like(dr_adv_s))**2) +
+ torch.sum((dr_adv_t - torch.ones_like(dr_adv_t))**2)
+ ) / (dr_adv_s.numel() + dr_adv_t.numel())
+ AdvLoss_d_f = 0.5 * (
+ torch.sum(df_adv_ss**2) +
+ torch.sum(df_adv_st**2) +
+ torch.sum(df_adv_tt**2) +
+ torch.sum(df_adv_ts**2)
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+ AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f
+
+ elif self.loss_type=='cgan':
+ # Regular GAN with sigmoid cross-entropy criterion (CGAN)
+ AdvLoss_d_r = (
+ F.binary_cross_entropy_with_logits(dr_adv_s, torch.ones_like(dr_adv_s), reduction='sum') +
+ F.binary_cross_entropy_with_logits(dr_adv_t, torch.ones_like(dr_adv_t), reduction='sum')
+ ) / (dr_adv_s.numel() + dr_adv_t.numel())
+ AdvLoss_d_f = (
+ F.binary_cross_entropy_with_logits(df_adv_ss, torch.zeros_like(df_adv_ss), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_st, torch.zeros_like(df_adv_st), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_tt, torch.zeros_like(df_adv_tt), reduction='sum') +
+ F.binary_cross_entropy_with_logits(df_adv_ts, torch.zeros_like(df_adv_ts), reduction='sum')
+ ) / (df_adv_ss.numel() + df_adv_st.numel() + df_adv_tt.numel() + df_adv_ts.numel())
+ AdvLoss_d = AdvLoss_d_r + AdvLoss_d_f
+
+ # Gradient penalty loss
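+        # (WGAN-GP penalty: E[(||grad_x D(x_hat)||_2 - 1)^2] over random interpolations
+        # between real and converted features. Note it is computed here for every
+        # loss_type, not only 'wgan'.)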
+ alpha_t = torch.rand(B_t, 1, 1, requires_grad=True).to(device)
+ interpolates = alpha_t * x_t + ((1 - alpha_t) * xf_ts)
+ interpolates = interpolates.to(device)
+ disc_interpolates, _ = self.dis(interpolates)
+ disc_interpolates = torch.sum(disc_interpolates)
+ gradients = torch.autograd.grad(outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+ create_graph=True, retain_graph=True, only_inputs=True)[0]
+ gradnorm = torch.sqrt(torch.sum(gradients * gradients, (1, 2)))
+ loss_gp_t = ((gradnorm - 1)**2).mean()
+
+ alpha_s = torch.rand(B_s, 1, 1, requires_grad=True).to(device)
+ interpolates = alpha_s * x_s + ((1 - alpha_s) * xf_st)
+ interpolates = interpolates.to(device)
+ disc_interpolates, _ = self.dis(interpolates)
+ disc_interpolates = torch.sum(disc_interpolates)
+ gradients = torch.autograd.grad(outputs=disc_interpolates,
+ inputs=interpolates,
+ grad_outputs=torch.ones(disc_interpolates.size()).to(device),
+ create_graph=True, retain_graph=True, only_inputs=True)[0]
+ gradnorm = torch.sqrt(torch.sum(gradients * gradients, (1, 2)))
+ loss_gp_s = ((gradnorm - 1)**2).mean()
+
+ GradLoss_d = loss_gp_s + loss_gp_t
+
+ return AdvLoss_d, GradLoss_d
+
+ def calc_clsloss_d(self, dr_cls_s, dr_cls_t, k_s, k_t):
+ device = dr_cls_s.device
+
+ dr_cls_s = dr_cls_s.permute(0,2,1).reshape(-1,self.n_spk)
+ dr_cls_t = dr_cls_t.permute(0,2,1).reshape(-1,self.n_spk)
+
+ cr_s = k_s*torch.ones(len(dr_cls_s), device=device, dtype=torch.long)
+ cr_t = k_t*torch.ones(len(dr_cls_t), device=device, dtype=torch.long)
+
+ ClsLoss_d = (
+ F.cross_entropy(dr_cls_s, cr_s, reduction='sum') +
+ F.cross_entropy(dr_cls_t, cr_t, reduction='sum')
+ ) / (dr_cls_s.numel() + dr_cls_t.numel())
+
+ return ClsLoss_d
+
+ def calc_gen_loss(self, x_s, x_t, k_s, k_t):
+ # Generator outputs
+ xf_ss = self.gen(x_s, k_s, k_s)
+ xf_ts = self.gen(x_t, k_s, k_t)
+ xf_tt = self.gen(x_t, k_t, k_t)
+ xf_st = self.gen(x_s, k_t, k_s)
+
+ # Discriminator outputs
+ df_adv_ss, df_cls_ss = self.dis(xf_ss)
+ df_adv_st, df_cls_st = self.dis(xf_st)
+ df_adv_tt, df_cls_tt = self.dis(xf_tt)
+ df_adv_ts, df_cls_ts = self.dis(xf_ts)
+
+ # Adversarial loss
+ AdvLoss_g = self.calc_advloss_g(df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts)
+
+ # Classifier loss
+ ClsLoss_g = self.calc_clsloss_g(df_cls_ss, df_cls_st, df_cls_tt, df_cls_ts, k_s, k_t)
+
+ # Cycle-consistency loss
+ CycLoss = (
+ torch.sum(torch.abs(x_s - self.gen(xf_st, k_s, k_t))) +
+ torch.sum(torch.abs(x_t - self.gen(xf_ts, k_t, k_s)))
+ ) / (x_s.numel() + x_t.numel())
+
+ # Reconstruction loss
+ RecLoss = (
+ torch.sum(torch.abs(x_s - xf_ss)) +
+ torch.sum(torch.abs(x_t - xf_tt))
+ ) / (x_s.numel() + x_t.numel())
+
+ return AdvLoss_g, ClsLoss_g, CycLoss, RecLoss
+
+ def calc_dis_loss(self, x_s, x_t, k_s, k_t):
+ device = x_s.device
+
+ # Generator outputs
+ xf_ss = self.gen(x_s, k_s, k_s)
+ xf_ts = self.gen(x_t, k_s, k_t)
+ xf_tt = self.gen(x_t, k_t, k_t)
+ xf_st = self.gen(x_s, k_t, k_s)
+
+ # Discriminator outputs
+ dr_adv_s, dr_cls_s = self.dis(x_s)
+ dr_adv_t, dr_cls_t = self.dis(x_t)
+ df_adv_ss, _ = self.dis(xf_ss)
+ df_adv_st, _ = self.dis(xf_st)
+ df_adv_tt, _ = self.dis(xf_tt)
+ df_adv_ts, _ = self.dis(xf_ts)
+
+ # Adversarial loss
+ AdvLoss_d, GradLoss_d = self.calc_advloss_d(x_s, x_t, xf_ts, xf_st, dr_adv_s, dr_adv_t, df_adv_ss, df_adv_st, df_adv_tt, df_adv_ts)
+
+ # Classifier loss
+ ClsLoss_d = self.calc_clsloss_d(dr_cls_s, dr_cls_t, k_s, k_t)
+
+ return AdvLoss_d, GradLoss_d, ClsLoss_d
diff --git a/run_talkingface.py b/run_talkingface.py
index 3989d566..a257a762 100644
--- a/run_talkingface.py
+++ b/run_talkingface.py
@@ -1,11 +1,19 @@
import argparse
-from talkingface.quick_start import run
+# from talkingface.quick_start import run
+from models.stargan import train_stargan_model
+
+
+def run(model_name, dataset_name, config_file_list=None, evaluate_model_file=None, train=False):
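+    # Only StarGAN training is wired up here; other model names or train=False are no-ops.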
+ if train:
+ if model_name == 'stargan':
+ train_stargan_model()
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--model", "-m", type=str, default="BPR", help="name of models")
+ parser.add_argument("--model", "-m", type=str, default="stargan", help="name of models")
parser.add_argument(
- "--dataset", "-d", type=str, default=None, help="name of datasets"
+ "--dataset", "-d", type=str, default='vctk', help="name of datasets"
)
parser.add_argument("--evaluate_model_file", type=str, default=None, help="The model file you want to evaluate")
parser.add_argument("--config_files", type=str, default=None, help="config files")
diff --git a/talkingface/data/Arctic.yaml b/talkingface/data/Arctic.yaml
new file mode 100644
index 00000000..72a17db6
--- /dev/null
+++ b/talkingface/data/Arctic.yaml
@@ -0,0 +1,12 @@
+flen: 1024
+fmax: 7600
+fmin: 80
+fs: 16000
+fshift: 128
+num_mels: 80
+top_db: 30
+trim_silence: false
+stat_path: "./dataset/vctk/filelist/stat.pkl"
+data_path: "./dataset/vctk/data"
+model: "stargan_vc"
+train_filelist: "./dataset/vctk/filelist/train.txt"
\ No newline at end of file
diff --git a/talkingface/data/dataprocess/__init__.py b/talkingface/data/dataprocess/__init__.py
index 7c7aef87..9a724efa 100644
--- a/talkingface/data/dataprocess/__init__.py
+++ b/talkingface/data/dataprocess/__init__.py
@@ -1 +1,2 @@
-from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio, Wav2LipPreprocessForInference
\ No newline at end of file
+from talkingface.data.dataprocess.wav2lip_process import Wav2LipAudio, Wav2LipPreprocessForInference
+from talkingface.data.dataprocess.stargan_vc_process import StarganAudio
\ No newline at end of file
diff --git a/talkingface/data/dataprocess/stargan_vc_process.py b/talkingface/data/dataprocess/stargan_vc_process.py
new file mode 100644
index 00000000..63c106a7
--- /dev/null
+++ b/talkingface/data/dataprocess/stargan_vc_process.py
@@ -0,0 +1,84 @@
+import warnings
+
+import librosa
+import numpy as np
+import soundfile as sf
+
+
+class StarganAudio:
+    def __init__(self, config):
+        # Store the preprocessing parameters (frame/hop size, mel settings, sample rate, ...).
+        self.kwargs = config
+
+ def logmelfilterbank(self, audio,
+ sampling_rate,
+ fft_size=1024,
+ hop_size=256,
+ win_length=None,
+ window="hann",
+ num_mels=80,
+ fmin=None,
+ fmax=None,
+ eps=1e-10,
+ ):
+ """Compute log-Mel filterbank feature.
+ Args:
+ audio (ndarray): Audio signal (T,).
+ sampling_rate (int): Sampling rate.
+ fft_size (int): FFT size.
+ hop_size (int): Hop size.
+ win_length (int): Window length. If set to None, it will be the same as fft_size.
+ window (str): Window function type.
+ num_mels (int): Number of mel basis.
+ fmin (int): Minimum frequency in mel basis calculation.
+ fmax (int): Maximum frequency in mel basis calculation.
+ eps (float): Epsilon value to avoid inf in log calculation.
+ Returns:
+ ndarray: Log Mel filterbank feature (#frames, num_mels).
+ """
+        # Magnitude spectrogram
+        x_stft = librosa.stft(audio, n_fft=fft_size, hop_length=hop_size,
+                              win_length=win_length, window=window, pad_mode="reflect")
+        spc = np.abs(x_stft).T  # (#frames, #bins)
+
+        # Mel filterbank (librosa >= 0.10 requires keyword arguments here)
+        fmin = 0 if fmin is None else fmin
+        fmax = sampling_rate / 2 if fmax is None else fmax
+        mel_basis = librosa.filters.mel(sr=sampling_rate, n_fft=fft_size,
+                                        n_mels=num_mels, fmin=fmin, fmax=fmax)
+
+        return np.log10(np.maximum(eps, np.dot(spc, mel_basis.T)))
+
+    def extract_melspec(self, src_filepath):
+        try:
+            warnings.filterwarnings('ignore')
+
+            # Pull the preprocessing parameters out of the config.
+            trim_silence = self.kwargs['trim_silence']
+            top_db = self.kwargs['top_db']
+            flen = self.kwargs['flen']
+            fshift = self.kwargs['fshift']
+            fmin = self.kwargs['fmin']
+            fmax = self.kwargs['fmax']
+            num_mels = self.kwargs['num_mels']
+            fs = self.kwargs['fs']
+
+            # Read the audio file.
+            audio, fs_ = sf.read(src_filepath)
+            if trim_silence:
+                # Trim leading/trailing silence if requested.
+                audio, _ = librosa.effects.trim(audio, top_db=top_db, frame_length=2048, hop_length=512)
+            if fs != fs_:
+                # Resample if the file's rate differs from the target rate
+                # (keyword arguments required by librosa >= 0.10).
+                audio = librosa.resample(audio, orig_sr=fs_, target_sr=fs)
+            # Extract the log-mel spectrogram.
+            melspec_raw = self.logmelfilterbank(audio, fs, fft_size=flen, hop_size=fshift,
+                                                fmin=fmin, fmax=fmax, num_mels=num_mels)
+            melspec_raw = melspec_raw.astype(np.float32)
+            melspec_raw = melspec_raw.T  # n_mels x n_frame
+            return melspec_raw
+
+        except Exception as e:
+            print(f"{src_filepath}...failed ({e}).")
+            return None
+
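+# Usage sketch (parameter values mirror talkingface/data/Arctic.yaml; the wav
+# path is hypothetical):
+#   config = {'trim_silence': False, 'top_db': 30, 'flen': 1024, 'fshift': 128,
+#             'fmin': 80, 'fmax': 7600, 'num_mels': 80, 'fs': 16000}
+#   mel = StarganAudio(config).extract_melspec('./dataset/vctk/data/p225/p225_001.wav')
+#   # mel: float32 ndarray of shape (num_mels, n_frame), or None on failure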
diff --git a/talkingface/data/dataset/__init__.py b/talkingface/data/dataset/__init__.py
index 3fd37538..d482fb31 100644
--- a/talkingface/data/dataset/__init__.py
+++ b/talkingface/data/dataset/__init__.py
@@ -1,2 +1,3 @@
from talkingface.data.dataset.wav2lip_dataset import Wav2LipDataset
-from talkingface.data.dataset.dataset import Dataset
\ No newline at end of file
+from talkingface.data.dataset.dataset import Dataset
+from talkingface.data.dataset.stargan_vc_dataset import StarganDataset
\ No newline at end of file
diff --git a/talkingface/data/dataset/compute_statistics_dataset.py b/talkingface/data/dataset/compute_statistics_dataset.py
new file mode 100644
index 00000000..a690c314
--- /dev/null
+++ b/talkingface/data/dataset/compute_statistics_dataset.py
@@ -0,0 +1,30 @@
+import os
+import h5py
+from tqdm import tqdm
+from sklearn.preprocessing import StandardScaler
+import pickle
+
+def walk_files(root, extension):
+ for path, dirs, files in os.walk(root):
+ for file in files:
+ if file.endswith(extension):
+ yield os.path.join(path, file)
+
+def compute_statistics(src, processor, stat_filepath="stat.pkl"):
+    melspec_scaler = StandardScaler()
+    # One sub-list of file paths per speaker directory under src.
+    filenames_all = [[os.path.join(src, d, t) for t in sorted(os.listdir(os.path.join(src, d)))] for d in os.listdir(src)]
+    for filepath_list in tqdm(filenames_all):
+        for filepath in filepath_list:
+            # Extract the mel-spectrogram and accumulate per-mel-bin statistics.
+            melspec = processor.extract_melspec(filepath)
+            if melspec is None:
+                continue
+            melspec_scaler.partial_fit(melspec.T)
+
+    with open(stat_filepath, mode='wb') as f:
+        pickle.dump(melspec_scaler, f)
+    print("Saved.")
+
+
+if __name__ == '__main__':
+ pass
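+    # Example sketch (paths mirror Arctic.yaml; building `config` is assumed):
+    #   from talkingface.data.dataprocess.stargan_vc_process import StarganAudio
+    #   compute_statistics('./dataset/vctk/data', StarganAudio(config),
+    #                      stat_filepath='./dataset/vctk/filelist/stat.pkl')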
\ No newline at end of file
diff --git a/talkingface/data/dataset/stargan_vc_dataset.py b/talkingface/data/dataset/stargan_vc_dataset.py
new file mode 100644
index 00000000..8d3d8561
--- /dev/null
+++ b/talkingface/data/dataset/stargan_vc_dataset.py
@@ -0,0 +1,91 @@
+# Import the required modules and helpers.
+
+import os
+from talkingface.data.dataprocess.stargan_vc_process import StarganAudio
+from talkingface.data.dataset.dataset import Dataset
+import math
+import numpy as np
+from sklearn.preprocessing import StandardScaler
+import pickle
+
+
+class StarganDataset(Dataset):
+ def __init__(self, config, file_list):
+ self.config = config
+        # List the speaker directories under the preprocessed data root.
+        feat_dirs = os.listdir(self.config['preprocessed_root'])
+
+        with open(file_list) as f:
+            text = f.readlines()
+        file_list_ = []
+        for line in text:
+            line = line.strip()
+            file_list_.append(line)
+        # One filename list per speaker directory under the root (the number of
+        # directories is the number of speakers), keeping only files named in
+        # the split's filelist.
+        self.filenames_all = [
+            [
+                os.path.join(self.config['preprocessed_root'], d, t) for t in sorted(os.listdir(os.path.join(self.config['preprocessed_root'], d))) if t in file_list_
+            ] for d in feat_dirs
+        ]
+        self.n_domain = len(self.filenames_all)
+        self.feat_dirs = feat_dirs
+        # Audio processor that turns wav files into mel-spectrograms.
+        self.process_audio = StarganAudio(config)
+
+        self.melspec_scaler = StandardScaler()
+        if os.path.exists(config['stat_path']):
+            with open(config['stat_path'], mode='rb') as f:
+                self.melspec_scaler = pickle.load(f)
+        else:
+            print("Melspec_scaler is None.")
+            self.melspec_scaler = None
+
+    def __len__(self):
+ return min(len(f) for f in self.filenames_all)
+
+    def __getitem__(self, idx):
+        melspec_list = []
+        # Iterate over each speaker's parallel utterance with index idx.
+        for d in range(self.n_domain):
+            # Extract this speaker's mel-spectrogram.
+            melspec = self.process_audio.extract_melspec(self.filenames_all[d][idx])  # n_freq x n_time
+            if self.melspec_scaler is not None:
+                # Standardize per mel bin, then transpose back to n_freq x n_time.
+                melspec = self.melspec_scaler.transform(melspec.T).T
+            melspec_list.append(melspec)
+        # One mel-spectrogram per speaker for this utterance index.
+        return melspec_list
+        # return {"melspec_list": melspec_list}
+
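+    # Usage sketch: ds = StarganDataset(config, config['train_filelist']);
+    # ds[idx] returns one (n_mels, n_frame) mel-spectrogram per speaker, which
+    # collate_fn below regroups into per-speaker batches.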
+ def collate_fn(self, batch):
+ #batch[b][s]: melspec (n_freq x n_frame)
+ #b: batch size
+ #s: speaker ID
+
+ batchsize = len(batch)
+ n_spk = len(batch[0])
+ melspec_list = [[batch[b][s] for b in range(batchsize)] for s in range(n_spk)]
+ #melspec_list[s][b]: melspec (n_freq x n_frame)
+ #s: speaker ID
+ #b: batch size
+
+ n_freq = melspec_list[0][0].shape[0]
+
+ X_list = []
+ for s in range(n_spk):
+ maxlen=0
+ for b in range(batchsize):
+ if maxlen