5 changes: 5 additions & 0 deletions .gitignore
@@ -1 +1,6 @@
.venv/*
*.pyc
checkpoint*/*
.gradio/certificate.pem
.python-version
*.wav
49 changes: 25 additions & 24 deletions pyproject.toml
@@ -9,39 +9,40 @@ allow-direct-references = true
name = "playdiffusion"
version = "0.1.0"
description = "Diffusion model for speech inpainting and TTS"
requires-python = "==3.11.*"
requires-python = ">=3.11"
license = { text = "Apache-2.0" }
dependencies = [
"torch==2.6.0",
"torchaudio==2.6.0",
"numpy==1.24.3",
"fairseq2==0.4.4",
"nltk==3.9.1",
"torch>=2.8.0",
"torchaudio>=2.8.0",
"numpy>=1.24.3",
"fairseq2>=0.5.2",
"nltk>=3.9.1",
"syllables @ git+https://github.com/playht/python-syllables.git",
"jiwer==3.1.0",
"pydantic==2.11.5",
"soundfile==0.13.1",
"boto3==1.38.22",
"tqdm==4.67.1",
"python-decouple==3.8",
"safetensors==0.5.3",
"tokenizers==0.21.1",
"librosa==0.10.1",
"scipy==1.11.4",
"scikit-learn==1.3.2",
"einops==0.8.1",
"torchtune==0.6.1",
"torchao==0.11.0",
"huggingface-hub==0.31.4",
"unidecode==1.4.0",
"jiwer>=3.1.0",
"pydantic>=2.11.5",
"soundfile>=0.13.1",
"boto3>=1.38.22",
"tqdm>=4.67.1",
"python-decouple>=3.8",
"safetensors>=0.6.0",
"tokenizers>=0.21.1",
"librosa>=0.10.1",
"scipy>=1.11.4",
"scikit-learn>=1.3.2",
"einops>=0.8.1",
"torchtune>=0.6.1",
"torchao>=0.11.0",
"huggingface-hub>=0.31.4",
"unidecode>=1.4.0",
]

[project.optional-dependencies]
demo = [
"gradio==5.31.0",
"openai==1.82.0",
"gradio>=5.31.0",
"openai>=1.82.0",
"openai-whisper>=20230314",
"whisper-timestamped>=0.0.11",
"torchcodec>=0.7.0",
]

[tool.hatch.build.targets.wheel]
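Since the pins are now lower bounds rather than exact versions, a quick runtime check can catch a stale environment before inference fails. A minimal sketch, assuming `packaging` is available; the floors are copied from the updated dependency list above:

```python
from importlib.metadata import version

from packaging.version import Version

# Lower bounds taken from the updated pyproject.toml; extend as needed.
FLOORS = {
    "torch": "2.8.0",
    "torchaudio": "2.8.0",
    "fairseq2": "0.5.2",
    "safetensors": "0.6.0",
}

for package, floor in FLOORS.items():
    installed = Version(version(package))
    if installed < Version(floor):
        raise RuntimeError(f"{package} {installed} is older than required {floor}")
```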
4 changes: 2 additions & 2 deletions src/playdiffusion/inference.py
@@ -698,7 +698,7 @@ def inpaint(self, input: InpaintInput):
print(f"Resampled wav: {resampled_wav.shape}")
self.timer("Resample")
with torch.inference_mode():
input_audio_tokens = self.mm.speech_tokenizer.waveform_to_units(
input_audio_tokens, _ = self.mm.speech_tokenizer.waveform_to_units(
resampled_wav.squeeze()
)
print(f"Input audio tokens: {input_audio_tokens.shape}")
@@ -844,7 +844,7 @@ def rvc(self, input: RVCInput):
print(f"Resampled wav: {resampled_wav.shape}")
self.timer("Resample")
with torch.inference_mode():
input_audio_tokens = self.mm.speech_tokenizer.waveform_to_units(
input_audio_tokens, _ = self.mm.speech_tokenizer.waveform_to_units(
resampled_wav.squeeze()
)
print(f"Input audio tokens: {input_audio_tokens.shape}")
3 changes: 1 addition & 2 deletions src/playdiffusion/models/speech_tokenizer/kmeans.py
@@ -1,11 +1,10 @@
import numpy as np
import torch
from fairseq2.typing import DataType, Device
from torch import Tensor, nn


class KmeansModel(nn.Module):
def __init__(self, km_path: str, device: Device, dtype: DataType):
def __init__(self, km_path: str, device: torch.device, dtype: torch.dtype):
super().__init__()
km_model = np.load(km_path)
centroids_numpy = km_model.transpose()
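With `fairseq2.typing` gone, the constructor takes plain `torch` types. A hypothetical construction sketch; the checkpoint path mirrors the default used by `SpeechTokenizer` below:

```python
import torch

from playdiffusion.models.speech_tokenizer.kmeans import KmeansModel

# Plain torch.device / torch.dtype replace fairseq2.typing.Device / DataType.
kmeans = KmeansModel(
    km_path="data/checkpoints/kmeans_10k.npy",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    dtype=torch.float16,
)
```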
87 changes: 54 additions & 33 deletions src/playdiffusion/models/speech_tokenizer/speech_tokenizer.py
@@ -2,10 +2,8 @@
from typing import List, Optional, Tuple, Union

import torch
from fairseq2.data import Collater
from fairseq2.models.sequence import SequenceBatch
from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
from fairseq2.typing import DataType, Device
from fairseq2.nn import BatchLayout
from torch.nn.utils.rnn import pad_sequence

from playdiffusion.models.speech_tokenizer.kmeans import KmeansModel
from playdiffusion.models.speech_tokenizer.xlsr_encoder import load_xlsr_encoder
@@ -32,8 +30,8 @@ def __init__(
self,
checkpoint: Union[str, None] = "data/checkpoints/xlsr2_1b_v2_custom.pt",
max_layer: Union[int, None] = 35,
device: Optional[Device] = None,
dtype: DataType = torch.float32,
device: Optional[torch.device] = None,
dtype: torch.dtype = torch.float32,
strict: bool = False,
eval: bool = True,
) -> None:
@@ -81,17 +79,20 @@ def dtype(self):
return next(self.parameters()).dtype

@torch.inference_mode()
def forward(self, batch: SequenceBatch) -> Tuple[torch.Tensor, PaddingMask]:
# The forward signature now accepts the padded sequences and the batch layout
def forward(self, seqs: torch.Tensor, layout: BatchLayout) -> Tuple[torch.Tensor, BatchLayout]:
"""
Minimal re-implementation that assumes we only loaded `max_layer` layers.
This is better as it doesn't require the full model to be loaded.

:param batch:
The batch of sequences to process.
:param seqs:
The batch of padded sequences.
:param layout:
The layout of the batch (containing sequence lengths).
"""
seqs, padding_mask = self.model.encoder_frontend(batch.seqs, batch.padding_mask)
encoder_output, padding_mask = self.model.encoder(seqs, padding_mask)
return encoder_output, padding_mask
seqs, layout_out = self.model.encoder_frontend(seqs, layout)
encoder_output = self.model.encoder(seqs, layout_out)
return encoder_output, layout_out


class SpeechTokenizer(torch.nn.Module):
@@ -111,11 +112,10 @@ def __init__(
self,
checkpoint: Union[str, None] = "data/checkpoints/xlsr2_1b_v2_custom.pt",
kmeans_layer_checkpoint: str = "data/checkpoints/kmeans_10k.npy",
dtype: DataType = torch.float16,
device: Optional[Device] = None,
dtype: torch.dtype = torch.float16,
device: Optional[torch.device] = None,
) -> None:
super().__init__()
self.collater = Collater(pad_value=1, pad_to_multiple=2)

if device is None:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
@@ -134,25 +134,46 @@ def device(self):
def dtype(self):
return next(self.parameters()).dtype

def create_batch(self, x: BATCH_INPUT) -> SequenceBatch:
src = self.collater(x)
seqs, padding_mask = get_seqs_and_padding_mask(src)
batch = SequenceBatch(seqs=seqs, padding_mask=padding_mask)
return batch
def create_batch(self, x: BATCH_INPUT) -> Tuple[torch.Tensor, BatchLayout]:
if isinstance(x, torch.Tensor):
x = [x]

lens: List[int] = [int(t.shape[0]) for t in x]
# Original code padded with 1, but for an audio model 0 makes more sense
seqs = pad_sequence(x, batch_first=True, padding_value=0.0)
seqs = seqs.to(self.device, self.dtype)

B, T_max = int(seqs.size(0)), int(seqs.size(1))
seqs_layout = BatchLayout(shape=(B, T_max), seq_lens=lens, device=seqs.device)
return seqs, seqs_layout

@torch.inference_mode()
def forward(self, batch: SequenceBatch) -> Tuple[torch.Tensor, PaddingMask]:
self.cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.cuda_stream):
z, padding_mask = self.encoder(batch)
units = self.kmeans(z)
self.gpu_memory_manager.check_and_cleanup()
torch.cuda.current_stream().wait_stream(self.cuda_stream)
return units, padding_mask
def forward(self, seqs: torch.Tensor, seqs_layout: BatchLayout) -> tuple[torch.Tensor, BatchLayout]:

units = None
if torch.cuda.is_available():
self.cuda_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.cuda_stream):
z, unit_layout = self.encoder(seqs, seqs_layout)
units = self.kmeans(z)
self.gpu_memory_manager.check_and_cleanup()
torch.cuda.current_stream().wait_stream(self.cuda_stream)
else:
z, unit_layout = self.encoder(seqs, seqs_layout)
units = self.kmeans(z) # Doesn't modify layout
return units, unit_layout

@torch.inference_mode()
def waveform_to_units(self, waveform: torch.Tensor) -> torch.Tensor:
waveform = waveform.to(self.device).to(self.dtype)
batch = self.create_batch(waveform)
units, _ = self(batch)
return units
def waveform_to_units(self, waveform: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]) -> tuple[torch.Tensor, BatchLayout]:
"""
Converts a single waveform tensor or a list of waveform tensors into audio tokens.

Returns a batch of audio tokens [B, T] and the corresponding BatchLayout.
Use unit_layout.seq_lens to get the lengths of the individual audio token tensors.

Output units are tokens of dtype torch.int64 with
0 <= token < num_embeddings (e.g., 10000 for kmeans_10k.npy).
"""
seqs, seqs_layout = self.create_batch(waveform)
units, unit_layout = self(seqs, seqs_layout)
return units, unit_layout
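A minimal usage sketch of the reworked tokenizer API, assuming the default checkpoints are present; the random waveforms are placeholders for real mono audio at the model's expected sample rate:

```python
import torch

from playdiffusion.models.speech_tokenizer.speech_tokenizer import SpeechTokenizer

tokenizer = SpeechTokenizer(
    checkpoint="data/checkpoints/xlsr2_1b_v2_custom.pt",
    kmeans_layer_checkpoint="data/checkpoints/kmeans_10k.npy",
)

# Two mono waveforms of different lengths; create_batch right-pads them with
# zeros and records the true lengths in the BatchLayout.
wavs = [torch.randn(16000), torch.randn(24000)]

units, layout = tokenizer.waveform_to_units(wavs)
print(units.shape)      # [B, T] int64 tokens, padded to the longest item
print(layout.seq_lens)  # per-item token lengths, used to trim the padding
```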
13 changes: 6 additions & 7 deletions src/playdiffusion/models/speech_tokenizer/xlsr_encoder.py
@@ -1,14 +1,13 @@
from typing import Tuple, Union

from fairseq2.models.wav2vec2._factory import (
from fairseq2.models.wav2vec2 import (
Wav2Vec2Factory,
Wav2Vec2Config,
Wav2Vec2EncoderConfig,
Wav2Vec2Model
)
from fairseq2.models.wav2vec2._model import Wav2Vec2Model
from fairseq2.nn.transformer import TransformerNormOrder
from fairseq2.typing import DataType, Device

from fairseq2.models.transformer import TransformerNormOrder
import torch

def _encoder_xlsr2_1b_v2() -> Wav2Vec2EncoderConfig:
"""
@@ -28,7 +27,7 @@ def _encoder_xlsr2_1b_v2() -> Wav2Vec2EncoderConfig:
feature_extractor_layer_descs=layer_descs, # type: ignore
feature_extractor_bias=True,
feature_extractor_layer_norm_convs=True,
feature_gradient_scale=1.0,
feature_grad_scale=1.0,
num_fbank_channels=0,
fbank_stride=0,
sample_fbank_every_k=0,
@@ -74,7 +73,7 @@ def _xlsr2_1b_v2() -> Wav2Vec2Config:


def load_xlsr_encoder(
device: Device, dtype: DataType, max_layer: Union[int, None] = 35
device: torch.device, dtype: torch.dtype, max_layer: Union[int, None] = 35
) -> Tuple[Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2EncoderConfig]:
"""
build_xlsr_1b_v2
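The loader likewise takes plain `torch` types now. A hypothetical call sketch based on the signature above; `max_layer=35` matches the default, and per the encoder docstring only the first `max_layer` transformer layers are loaded:

```python
import torch

from playdiffusion.models.speech_tokenizer.xlsr_encoder import load_xlsr_encoder

# Returns the truncated XLSR-2 1B v2 encoder together with its configs.
model, model_config, encoder_config = load_xlsr_encoder(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    dtype=torch.float16,
    max_layer=35,
)
```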