Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 58 additions & 7 deletions backend/backends/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,10 @@ async def generate(
"""
Generate audio from text using voice prompt.

Automatically splits long text into sentence-boundary chunks,
generates each chunk individually, and concatenates with crossfade.
Optionally upsamples to a higher sample rate based on quality setting.

Args:
text: Text to synthesize
voice_prompt: Voice prompt dictionary from create_voice_prompt
Expand All @@ -320,29 +324,76 @@ async def generate(
Returns:
Tuple of (audio_array, sample_rate)
"""
from ..utils.chunked_tts import (
split_text_into_chunks,
concatenate_audio_chunks,
resample_audio,
_tts_settings,
)

# Load model
await self.load_model_async(None)

max_chars = _tts_settings["max_chunk_chars"]
chunks = split_text_into_chunks(text, max_chars)

if len(chunks) <= 1:
# Short text -- single-shot generation (fast path)
audio, sample_rate = await self._generate_single(
text, voice_prompt, seed, instruct,
)
else:
# Long text -- chunked generation
print(f"[chunked-tts] Splitting {len(text)} chars into {len(chunks)} chunks")
audio_chunks: List[np.ndarray] = []
sample_rate = None

for i, chunk_text in enumerate(chunks):
print(f"[chunked-tts] Generating chunk {i + 1}/{len(chunks)} ({len(chunk_text)} chars)")
chunk_audio, chunk_sr = await self._generate_single(
chunk_text, voice_prompt, seed, instruct,
)
audio_chunks.append(np.asarray(chunk_audio, dtype=np.float32))
if sample_rate is None:
sample_rate = chunk_sr

audio = concatenate_audio_chunks(audio_chunks, sample_rate)

# Quality-based resampling
quality = _tts_settings["quality"]
if quality == "high":
target_rate = _tts_settings["upsample_rate"]
if sample_rate != target_rate:
audio = resample_audio(
np.asarray(audio, dtype=np.float32), sample_rate, target_rate,
)
sample_rate = target_rate

return audio, sample_rate

async def _generate_single(
self,
text: str,
voice_prompt: dict,
seed: Optional[int] = None,
instruct: Optional[str] = None,
) -> Tuple[np.ndarray, int]:
"""Generate audio for a single text segment (no chunking)."""

def _generate_sync():
"""Run synchronous generation in thread pool."""
# Set seed if provided
if seed is not None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

# Generate audio - this is the blocking operation
wavs, sample_rate = self.model.generate_voice_clone(
text=text,
voice_clone_prompt=voice_prompt,
instruct=instruct,
)
return wavs[0], sample_rate

# Run blocking inference in thread pool to avoid blocking event loop
audio, sample_rate = await asyncio.to_thread(_generate_sync)

return audio, sample_rate
return await asyncio.to_thread(_generate_sync)


class PyTorchSTTBackend:
Expand Down
17 changes: 17 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1109,6 +1109,23 @@ async def get_sample_audio(sample_id: str, db: Session = Depends(get_db)):
# MODEL MANAGEMENT
# ============================================

@app.get("/tts/settings")
async def get_tts_settings():
    """Expose the current TTS chunking and quality configuration."""
    from .utils import chunked_tts

    return chunked_tts.get_tts_settings()


@app.post("/tts/settings")
async def update_tts_settings(request: models.TTSSettingsUpdate):
    """Update TTS quality and chunking settings at runtime.

    Fields left unset (None) in the request are ignored; returns the
    resulting settings dict, or HTTP 400 on invalid values.
    """
    from .utils.chunked_tts import update_tts_settings as _update_settings
    try:
        return _update_settings(request.model_dump(exclude_none=True))
    except ValueError as e:
        # Chain the original error (raise ... from e) so tracebacks keep
        # the root cause instead of reporting "during handling of the
        # above exception, another exception occurred".
        raise HTTPException(status_code=400, detail=str(e)) from e


@app.post("/models/load")
async def load_model(model_size: str = "1.7B"):
"""Manually load TTS model."""
Expand Down
8 changes: 7 additions & 1 deletion backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,16 @@ class Config:
from_attributes = True


class TTSSettingsUpdate(BaseModel):
    """Request model for updating TTS settings.

    All fields are optional; fields left as None are excluded by the
    endpoint (model_dump(exclude_none=True)) and therefore unchanged.
    """
    # Output quality preset; restricted to "standard" or "high".
    quality: Optional[str] = Field(None, pattern="^(standard|high)$")
    # Maximum characters per generated chunk (must be within 100-5000).
    max_chunk_chars: Optional[int] = Field(None, ge=100, le=5000)


class GenerationRequest(BaseModel):
"""Request model for voice generation."""
profile_id: str
text: str = Field(..., min_length=1, max_length=5000)
text: str = Field(..., min_length=1, max_length=50000)
language: str = Field(default="en", pattern="^(zh|en|ja|ko|de|fr|ru|pt|es|it)$")
seed: Optional[int] = Field(None, ge=0)
model_size: Optional[str] = Field(default="1.7B", pattern="^(1\\.7B|0\\.6B)$")
Expand Down
1 change: 1 addition & 0 deletions backend/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ qwen-tts>=0.0.5
librosa>=0.10.0
soundfile>=0.12.0
numpy>=1.24.0
soxr>=0.3.0

# Utilities
python-multipart>=0.0.6
Expand Down
176 changes: 176 additions & 0 deletions backend/utils/chunked_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
"""
Chunked TTS generation with quality selection.

Splits long text into sentence-boundary chunks, generates audio per-chunk,
and concatenates with crossfade. Optionally upsamples to 44.1kHz for
higher quality output.

Environment variables:
TTS_QUALITY: "standard" (24kHz native) or "high" (44.1kHz upsampled)
TTS_MAX_CHUNK_CHARS: Max characters per chunk (default 800)
TTS_UPSAMPLE_RATE: Target sample rate for high quality (default 44100)
"""

import logging
import os
import re
from typing import List

import numpy as np

logger = logging.getLogger("voicebox.chunked-tts")

# ---------------------------------------------------------------------------
# Runtime-mutable settings
# ---------------------------------------------------------------------------

# Mutable module-level settings dict, seeded from environment variables at
# import time and mutated in place by update_tts_settings().
_tts_settings = {
    "quality": os.getenv("TTS_QUALITY", "standard"),  # "standard" or "high"
    "max_chunk_chars": int(os.getenv("TTS_MAX_CHUNK_CHARS", "800")),  # max chars per chunk
    "upsample_rate": int(os.getenv("TTS_UPSAMPLE_RATE", "44100")),  # target Hz for "high"
}

# Effective output sample rate for each quality preset.
QUALITY_RATES = {
    "standard": 24000,  # Qwen3-TTS native sample rate
    "high": 44100,  # CD-quality upsampled via soxr
}


def get_tts_settings() -> dict:
    """Report the active TTS chunking/quality configuration.

    Returns a dict with the selected quality, its effective sample rate
    (falling back to 24000 Hz for an unknown quality value), the chunk
    character limit, and the list of selectable qualities.
    """
    selected = _tts_settings["quality"]
    info = dict(
        quality=selected,
        sample_rate=QUALITY_RATES.get(selected, 24000),
        max_chunk_chars=_tts_settings["max_chunk_chars"],
        available_qualities=list(QUALITY_RATES),
    )
    return info


def update_tts_settings(updates: dict) -> dict:
    """Apply runtime overrides to the TTS settings and return the result.

    Only the 'quality' and 'max_chunk_chars' keys are honored.

    Raises:
        ValueError: for an unknown quality or an out-of-range chunk size.
    """
    if "quality" in updates:
        requested = updates["quality"]
        if requested not in QUALITY_RATES:
            raise ValueError(
                f"Invalid quality '{requested}'. Must be one of {list(QUALITY_RATES.keys())}"
            )
        _tts_settings["quality"] = requested

    if "max_chunk_chars" in updates:
        limit = int(updates["max_chunk_chars"])
        if not 100 <= limit <= 5000:
            raise ValueError("max_chunk_chars must be between 100 and 5000")
        _tts_settings["max_chunk_chars"] = limit

    return get_tts_settings()


# ---------------------------------------------------------------------------
# Text splitting
# ---------------------------------------------------------------------------

def split_text_into_chunks(text: str, max_chars: int = 800) -> List[str]:
    """Split text at sentence boundaries, with clause and word fallbacks.

    Priority: sentence-end (.!?) > clause boundary (;:,) > whitespace > hard cut.
    Returns [] for empty/whitespace input and a single chunk when the
    whole text fits within max_chars.
    """
    rest = text.strip()
    if not rest:
        return []
    if len(rest) <= max_chars:
        return [rest]

    pieces: List[str] = []
    while rest:
        rest = rest.strip()
        if not rest:
            break
        if len(rest) <= max_chars:
            pieces.append(rest)
            break

        window = rest[:max_chars]

        # Boundary preference: sentence end, then clause punctuation,
        # then the last space, then a hard cut at the character limit.
        cut = _find_last_sentence_end(window)
        if cut < 0:
            cut = _find_last_clause_boundary(window)
        if cut < 0:
            cut = window.rfind(" ")
        if cut < 0:
            cut = max_chars - 1

        piece = rest[: cut + 1].strip()
        if piece:
            pieces.append(piece)
        rest = rest[cut + 1 :]

    return pieces


def _find_last_sentence_end(text: str) -> int:
best = -1
for m in re.finditer(r"[.!?](?:\s|$)", text):
best = m.start()
return best


def _find_last_clause_boundary(text: str) -> int:
best = -1
for m in re.finditer(r"[;:,\u2014](?:\s|$)", text):
best = m.start()
return best


# ---------------------------------------------------------------------------
# Audio concatenation
# ---------------------------------------------------------------------------

def concatenate_audio_chunks(
    chunks: List[np.ndarray],
    sr: int,
    crossfade_ms: int = 50,
) -> np.ndarray:
    """Join audio arrays end to end, linearly blending a short overlap
    between consecutive chunks so seams do not click.

    Empty input yields an empty float32 array; a single chunk is
    returned unchanged.
    """
    if not chunks:
        return np.array([], dtype=np.float32)
    if len(chunks) == 1:
        return chunks[0]

    fade_len = int(sr * crossfade_ms / 1000)
    out = chunks[0].copy()

    for nxt in chunks[1:]:
        if len(nxt) == 0:
            continue
        # Overlap is capped by both signals' lengths.
        n = min(fade_len, len(out), len(nxt))
        if n <= 0:
            out = np.concatenate([out, nxt])
            continue
        ramp_down = np.linspace(1.0, 0.0, n, dtype=np.float32)
        ramp_up = np.linspace(0.0, 1.0, n, dtype=np.float32)
        out[-n:] = out[-n:] * ramp_down + nxt[:n] * ramp_up
        out = np.concatenate([out, nxt[n:]])

    return out


# ---------------------------------------------------------------------------
# Resampling
# ---------------------------------------------------------------------------

def resample_audio(audio: np.ndarray, src_rate: int, dst_rate: int) -> np.ndarray:
    """Resample audio using soxr (VHQ), with linear-interp fallback.

    Returns the input unchanged when the rates already match.
    """
    if src_rate == dst_rate:
        return audio

    try:
        import soxr
    except ImportError:
        # Degrade gracefully when the optional soxr dependency is absent.
        logger.warning("soxr not installed; falling back to linear interpolation")
        scale = dst_rate / src_rate
        n_out = int(len(audio) * scale)
        positions = np.linspace(0, len(audio) - 1, n_out)
        return np.interp(positions, np.arange(len(audio)), audio).astype(np.float32)

    return soxr.resample(audio, src_rate, dst_rate, quality="VHQ")