Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions .ipynb_checkpoints/download-checkpoint.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 6,
"id": "ea3972c8-dad2-4638-905f-aab636621fa7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"'wget' 不是内部或外部命令,也不是可运行的程序\n",
"或批处理文件。\n"
]
}
],
"source": [
"import requests\n",
"import os\n",
"\n",
"# 文件的URL\n",
    "url = \"https://hf-mirror.com/yl4579/StyleTTS2-LJSpeech/resolve/main/Models/LJSpeech/epoch_2nd_00100.pth\"\n",
"\n",
"# 发起GET请求并保存文件\n",
"response = requests.get(url)\n",
"file_path = 'checkpoints/epoch_2nd_00100.pth'\n",
"with open(file_path, 'wb') as file:\n",
" file.write(response.content)\n",
"\n",
"print(\"文件下载完成\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b90e9033-1f86-4c69-a4ab-d5c03a2b5741",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
253 changes: 253 additions & 0 deletions .ipynb_checkpoints/losses-checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoModel

class SpectralConvergengeLoss(torch.nn.Module):
    """Spectral convergence loss.

    Computes the relative L1 distance between a predicted magnitude
    spectrogram and the ground-truth one.
    """

    def __init__(self):
        """Initialize the spectral convergence loss module."""
        super(SpectralConvergengeLoss, self).__init__()

    def forward(self, x_mag, y_mag):
        """Compute the loss.

        Args:
            x_mag (Tensor): Magnitude spectrogram of the predicted signal
                (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of the ground-truth signal
                (B, #frames, #freq_bins).

        Returns:
            Tensor: Scalar spectral convergence loss value.
        """
        diff_norm = torch.norm(y_mag - x_mag, p=1)
        reference_norm = torch.norm(y_mag, p=1)
        return diff_norm / reference_norm

class STFTLoss(torch.nn.Module):
    """Single-resolution STFT loss.

    Converts both signals to normalized log-mel spectrograms and returns
    the spectral convergence loss between them.
    """

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window):
        """Initialize the STFT loss module.

        Args:
            fft_size (int): FFT size of the mel spectrogram.
            shift_size (int): Hop length in samples.
            win_length (int): Analysis window length in samples.
            window (callable): Window function (e.g. ``torch.hann_window``).
        """
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # NOTE(review): sample_rate is hard-coded to 24 kHz — confirm it
        # matches the audio this loss is applied to.
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=24000,
            n_fft=fft_size,
            win_length=win_length,
            hop_length=shift_size,
            window_fn=window,
        )
        self.spectral_convergenge_loss = SpectralConvergengeLoss()

    def forward(self, x, y):
        """Compute the loss.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Ground-truth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss on normalized log-mel spectrograms.
        """
        mean, std = -4, 4

        def _normalized_log_mel(signal):
            # 1e-5 floor avoids log(0); (mean, std) whitening matches the
            # convention used elsewhere in this project.
            mel = self.to_mel(signal)
            return (torch.log(1e-5 + mel) - mean) / std

        return self.spectral_convergenge_loss(_normalized_log_mel(x),
                                              _normalized_log_mel(y))


class MultiResolutionSTFTLoss(torch.nn.Module):
    """Multi-resolution STFT loss.

    Averages :class:`STFTLoss` over several FFT/hop/window configurations.
    """

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window=torch.hann_window):
        """Initialize the multi-resolution STFT loss module.

        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (callable): Window function.
        """
        super(MultiResolutionSTFTLoss, self).__init__()
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList(
            STFTLoss(n_fft, hop, win, window)
            for n_fft, hop, win in zip(fft_sizes, hop_sizes, win_lengths)
        )

    def forward(self, x, y):
        """Compute the loss.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Ground-truth signal (B, T).

        Returns:
            Tensor: Mean spectral convergence loss across all resolutions.
        """
        total = sum(stft_loss(x, y) for stft_loss in self.stft_losses)
        return total / len(self.stft_losses)


def feature_loss(fmap_r, fmap_g):
    """L1 feature-matching loss over discriminator feature maps.

    Args:
        fmap_r: Per-discriminator lists of feature maps for real audio.
        fmap_g: Per-discriminator lists of feature maps for generated audio.

    Returns:
        Tensor: Sum of mean absolute differences, scaled by 2.
    """
    total = 0
    for real_feats, fake_feats in zip(fmap_r, fmap_g):
        for real_layer, fake_layer in zip(real_feats, fake_feats):
            total = total + torch.mean(torch.abs(real_layer - fake_layer))
    return total * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """LSGAN discriminator loss: real scores pushed to 1, fake scores to 0.

    Args:
        disc_real_outputs: List of discriminator outputs for real audio.
        disc_generated_outputs: List of discriminator outputs for generated audio.

    Returns:
        tuple: (total loss tensor, per-discriminator real losses as floats,
        per-discriminator fake losses as floats).
    """
    total = 0
    r_losses, g_losses = [], []
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1 - real_out) ** 2)
        fake_term = torch.mean(fake_out ** 2)
        total = total + (real_term + fake_term)
        r_losses.append(real_term.item())
        g_losses.append(fake_term.item())
    return total, r_losses, g_losses


def generator_loss(disc_outputs):
    """LSGAN generator loss: fake scores pushed toward 1.

    Args:
        disc_outputs: List of discriminator outputs for generated audio.

    Returns:
        tuple: (total loss tensor, list of per-discriminator loss tensors).
    """
    gen_losses = [torch.mean((1 - fake_out) ** 2) for fake_out in disc_outputs]
    total = sum(gen_losses)
    return total, gen_losses

""" https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
loss = 0
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
tau = 0.04
m_DG = torch.median((dr-dg))
L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
loss += tau - F.relu(tau - L_rel)
return loss

def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
    """Truncated Pointwise Relativistic LS loss, generator side.

    Reference: https://dl.acm.org/doi/abs/10.1145/3573834.3574506

    NOTE(review): the original loop binds the FIRST argument to the "fake"
    role and the SECOND to the "real" role (``for dg, dr in zip(real, gen)``),
    i.e. the roles are swapped relative to the parameter names — presumably
    intentional for the relativistic generator objective, preserved as-is.

    Args:
        disc_real_outputs: List of discriminator outputs for real audio.
        disc_generated_outputs: List of discriminator outputs for generated audio.

    Returns:
        Tensor: Accumulated loss over all discriminator outputs.
    """
    tau = 0.04
    total = 0
    for real_role, fake_role in zip(disc_generated_outputs, disc_real_outputs):
        gap = real_role - fake_role
        median_gap = torch.median(gap)
        relative = torch.mean(((gap - median_gap) ** 2)[gap < median_gap])
        total = total + (tau - F.relu(tau - relative))
    return total

class GeneratorLoss(torch.nn.Module):
    """Combined adversarial generator loss over two discriminators.

    Sums LSGAN, feature-matching, and TPRLS terms from a multi-period
    discriminator (``mpd``) and a multi-scale discriminator (``msd``).
    """

    def __init__(self, mpd, msd):
        """
        Args:
            mpd: Multi-period discriminator; called as ``mpd(y, y_hat)`` and
                expected to return (real scores, fake scores, real fmaps, fake fmaps).
            msd: Multi-scale discriminator with the same call convention.
        """
        super(GeneratorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        """Return the mean total generator loss for (real, generated) audio."""
        p_real, p_fake, p_fmap_real, p_fmap_fake = self.mpd(y, y_hat)
        s_real, s_fake, s_fmap_real, s_fmap_fake = self.msd(y, y_hat)

        # Feature-matching terms.
        fm_period = feature_loss(p_fmap_real, p_fmap_fake)
        fm_scale = feature_loss(s_fmap_real, s_fmap_fake)

        # Adversarial (LSGAN) terms; per-discriminator breakdowns are unused.
        adv_period, _ = generator_loss(p_fake)
        adv_scale, _ = generator_loss(s_fake)

        # Relativistic terms.
        relativistic = (generator_TPRLS_loss(p_real, p_fake)
                        + generator_TPRLS_loss(s_real, s_fake))

        # Same summation order as the original implementation.
        total = adv_scale + adv_period + fm_scale + fm_period + relativistic
        return total.mean()

class DiscriminatorLoss(torch.nn.Module):
    """Combined discriminator loss over two discriminators.

    Sums LSGAN and TPRLS terms from a multi-period discriminator (``mpd``)
    and a multi-scale discriminator (``msd``).
    """

    def __init__(self, mpd, msd):
        """
        Args:
            mpd: Multi-period discriminator; called as ``mpd(y, y_hat)``.
            msd: Multi-scale discriminator with the same call convention.
        """
        super(DiscriminatorLoss, self).__init__()
        self.mpd = mpd
        self.msd = msd

    def forward(self, y, y_hat):
        """Return the mean total discriminator loss for (real, generated) audio."""
        # Multi-period discriminator (feature maps unused here).
        p_real, p_fake, _, _ = self.mpd(y, y_hat)
        loss_period, _, _ = discriminator_loss(p_real, p_fake)

        # Multi-scale discriminator.
        s_real, s_fake, _, _ = self.msd(y, y_hat)
        loss_scale, _, _ = discriminator_loss(s_real, s_fake)

        relativistic = (discriminator_TPRLS_loss(p_real, p_fake)
                        + discriminator_TPRLS_loss(s_real, s_fake))

        # Same summation order as the original implementation.
        total = loss_scale + loss_period + relativistic
        return total.mean()


class WavLMLoss(torch.nn.Module):
    """Losses computed on hidden states of a pretrained WavLM model (SLM).

    Provides an L1 feature-matching loss between real and reconstructed
    audio (``forward``), plus LSGAN-style generator/discriminator losses
    produced by a discriminator head ``wd`` applied to stacked WavLM
    hidden states.
    """

    def __init__(self, model, wd, model_sr, slm_sr=16000):
        """
        Args:
            model: Hugging Face id/path of the WavLM checkpoint to load.
            wd: Discriminator head applied to stacked WavLM hidden states.
            model_sr (int): Sample rate of the audio passed to this module.
            slm_sr (int): Sample rate expected by WavLM (default 16 kHz).
        """
        super(WavLMLoss, self).__init__()
        self.wavlm = AutoModel.from_pretrained(model)
        self.wd = wd
        self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)

    def forward(self, wav, y_rec):
        """L1 feature-matching loss over all WavLM hidden states."""
        # Reference embeddings carry no gradient; the reconstruction path does.
        with torch.no_grad():
            ref_states = self.wavlm(
                input_values=self.resample(wav),
                output_hidden_states=True,
            ).hidden_states
        # NOTE(review): .squeeze() is applied only here, not in generator()/
        # discriminator() — looks shape-dependent; confirm the intended
        # input shapes. Preserved as-is.
        rec_states = self.wavlm(
            input_values=self.resample(y_rec).squeeze(),
            output_hidden_states=True,
        ).hidden_states

        total = 0
        for ref, rec in zip(ref_states, rec_states):
            total = total + torch.mean(torch.abs(ref - rec))
        return total.mean()

    def generator(self, y_rec):
        """LSGAN generator loss from the SLM discriminator head."""
        states = self.wavlm(
            input_values=self.resample(y_rec),
            output_hidden_states=True,
        ).hidden_states
        # Stack all hidden states into one feature tensor for the head.
        feats = torch.stack(states, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
        scores = self.wd(feats)
        return torch.mean((1 - scores) ** 2)

    def discriminator(self, wav, y_rec):
        """LSGAN discriminator loss on real vs reconstructed SLM features."""
        # Feature extraction is frozen; only the head ``wd`` gets gradients.
        with torch.no_grad():
            real_states = self.wavlm(
                input_values=self.resample(wav),
                output_hidden_states=True,
            ).hidden_states
            rec_states = self.wavlm(
                input_values=self.resample(y_rec),
                output_hidden_states=True,
            ).hidden_states

            real_feats = torch.stack(real_states, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)
            rec_feats = torch.stack(rec_states, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)

        real_scores = self.wd(real_feats)
        rec_scores = self.wd(rec_feats)

        real_term = torch.mean((1 - real_scores) ** 2)
        rec_term = torch.mean(rec_scores ** 2)

        return (real_term + rec_term).mean()

    def discriminator_forward(self, wav):
        """Return discriminator-head scores for real audio only."""
        with torch.no_grad():
            states = self.wavlm(
                input_values=self.resample(wav),
                output_hidden_states=True,
            ).hidden_states
            feats = torch.stack(states, dim=1).transpose(-1, -2).flatten(start_dim=1, end_dim=2)

        return self.wd(feats)
24 changes: 24 additions & 0 deletions .ipynb_checkpoints/run_talkingface-checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import argparse
from talkingface.quick_start import run

if __name__ == "__main__":
    # Command-line entry point: parse model/dataset options and launch a run.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", "-m", type=str, default="BPR", help="name of models")
    parser.add_argument("--dataset", "-d", type=str, default=None, help="name of datasets")
    parser.add_argument("--evaluate_model_file", type=str, default=None, help="The model file you want to evaluate")
    parser.add_argument("--config_files", type=str, default=None, help="config files")

    # parse_known_args: ignore unrecognized extra flags instead of erroring.
    args, _ = parser.parse_known_args()

    # NOTE(review): config_file_list and --evaluate_model_file are parsed but
    # never forwarded to run() (the kwargs were commented out upstream) —
    # confirm whether run() should receive them. Behavior preserved.
    config_file_list = (
        args.config_files.strip().split(" ") if args.config_files else None
    )

    run(args.model, args.dataset)
Loading