diff --git a/backend/backends/pytorch_backend.py b/backend/backends/pytorch_backend.py
index 26f3872..0f2135b 100644
--- a/backend/backends/pytorch_backend.py
+++ b/backend/backends/pytorch_backend.py
@@ -310,6 +310,10 @@ async def generate(
         """
         Generate audio from text using voice prompt.
 
+        Automatically splits long text into sentence-boundary chunks,
+        generates each chunk individually, and concatenates with crossfade.
+        Optionally upsamples to a higher sample rate based on quality setting.
+
         Args:
             text: Text to synthesize
             voice_prompt: Voice prompt dictionary from create_voice_prompt
@@ -320,18 +324,68 @@ async def generate(
         Returns:
             Tuple of (audio_array, sample_rate)
         """
+        from ..utils.chunked_tts import (
+            split_text_into_chunks,
+            concatenate_audio_chunks,
+            resample_audio,
+            _tts_settings,
+        )
+
         # Load model
         await self.load_model_async(None)
 
+        max_chars = _tts_settings["max_chunk_chars"]
+        chunks = split_text_into_chunks(text, max_chars)
+
+        if len(chunks) <= 1:
+            # Short text -- single-shot generation (fast path)
+            audio, sample_rate = await self._generate_single(
+                text, voice_prompt, seed, instruct,
+            )
+        else:
+            # Long text -- chunked generation
+            print(f"[chunked-tts] Splitting {len(text)} chars into {len(chunks)} chunks")
+            audio_chunks: List[np.ndarray] = []
+            sample_rate = None
+
+            for i, chunk_text in enumerate(chunks):
+                print(f"[chunked-tts] Generating chunk {i + 1}/{len(chunks)} ({len(chunk_text)} chars)")
+                chunk_audio, chunk_sr = await self._generate_single(
+                    chunk_text, voice_prompt, seed, instruct,
+                )
+                audio_chunks.append(np.asarray(chunk_audio, dtype=np.float32))
+                if sample_rate is None:
+                    sample_rate = chunk_sr
+
+            audio = concatenate_audio_chunks(audio_chunks, sample_rate)
+
+        # Quality-based resampling
+        quality = _tts_settings["quality"]
+        if quality == "high":
+            target_rate = _tts_settings["upsample_rate"]
+            if sample_rate != target_rate:
+                audio = resample_audio(
+                    np.asarray(audio, dtype=np.float32), sample_rate, target_rate,
+                )
+                sample_rate = target_rate
+
+        return audio, sample_rate
+
+    async def _generate_single(
+        self,
+        text: str,
+        voice_prompt: dict,
+        seed: Optional[int] = None,
+        instruct: Optional[str] = None,
+    ) -> Tuple[np.ndarray, int]:
+        """Generate audio for a single text segment (no chunking)."""
+
         def _generate_sync():
-            """Run synchronous generation in thread pool."""
-            # Set seed if provided
             if seed is not None:
                 torch.manual_seed(seed)
                 if torch.cuda.is_available():
                     torch.cuda.manual_seed(seed)
 
-            # Generate audio - this is the blocking operation
             wavs, sample_rate = self.model.generate_voice_clone(
                 text=text,
                 voice_clone_prompt=voice_prompt,
@@ -339,10 +393,7 @@ def _generate_sync():
             )
             return wavs[0], sample_rate
 
-        # Run blocking inference in thread pool to avoid blocking event loop
-        audio, sample_rate = await asyncio.to_thread(_generate_sync)
-
-        return audio, sample_rate
+        return await asyncio.to_thread(_generate_sync)
 
 
 class PyTorchSTTBackend:
diff --git a/backend/main.py b/backend/main.py
index 59fb9e1..6c8d876 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1109,6 +1109,23 @@ async def get_sample_audio(sample_id: str, db: Session = Depends(get_db)):
 # MODEL MANAGEMENT
 # ============================================
 
+@app.get("/tts/settings")
+async def get_tts_settings():
+    """Get current TTS chunking and quality settings."""
+    from .utils.chunked_tts import get_tts_settings as _get_settings
+    return _get_settings()
+
+
+@app.post("/tts/settings")
+async def update_tts_settings(request: models.TTSSettingsUpdate):
+    """Update TTS quality and chunking settings at runtime."""
+    from .utils.chunked_tts import update_tts_settings as _update_settings
+    try:
+        return _update_settings(request.model_dump(exclude_none=True))
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+
+
 @app.post("/models/load")
 async def load_model(model_size: str = "1.7B"):
     """Manually load TTS model."""
diff --git a/backend/models.py b/backend/models.py
index 59e4540..a9dd35c 100644
--- a/backend/models.py
+++ b/backend/models.py
@@ -49,10 +49,16 @@ class Config:
         from_attributes = True
 
 
+class TTSSettingsUpdate(BaseModel):
+    """Request model for updating TTS settings."""
+    quality: Optional[str] = Field(None, pattern="^(standard|high)$")
+    max_chunk_chars: Optional[int] = Field(None, ge=100, le=5000)
+
+
 class GenerationRequest(BaseModel):
     """Request model for voice generation."""
     profile_id: str
-    text: str = Field(..., min_length=1, max_length=5000)
+    text: str = Field(..., min_length=1, max_length=50000)
     language: str = Field(default="en", pattern="^(zh|en|ja|ko|de|fr|ru|pt|es|it)$")
     seed: Optional[int] = Field(None, ge=0)
     model_size: Optional[str] = Field(default="1.7B", pattern="^(1\\.7B|0\\.6B)$")
diff --git a/backend/requirements.txt b/backend/requirements.txt
index e0f6ded..259690a 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -18,6 +18,7 @@ qwen-tts>=0.0.5
 librosa>=0.10.0
 soundfile>=0.12.0
 numpy>=1.24.0
+soxr>=0.3.0
 
 # Utilities
 python-multipart>=0.0.6
diff --git a/backend/utils/chunked_tts.py b/backend/utils/chunked_tts.py
new file mode 100644
index 0000000..6eb3962
--- /dev/null
+++ b/backend/utils/chunked_tts.py
@@ -0,0 +1,176 @@
+"""
+Chunked TTS generation with quality selection.
+
+Splits long text into sentence-boundary chunks, generates audio per-chunk,
+and concatenates with crossfade. Optionally upsamples to 44.1kHz for
+higher quality output.
+
+Environment variables:
+    TTS_QUALITY: "standard" (24kHz native) or "high" (44.1kHz upsampled)
+    TTS_MAX_CHUNK_CHARS: Max characters per chunk (default 800)
+    TTS_UPSAMPLE_RATE: Target sample rate for high quality (default 44100)
+"""
+
+import logging
+import os
+import re
+from typing import List
+
+import numpy as np
+
+logger = logging.getLogger("voicebox.chunked-tts")
+
+# ---------------------------------------------------------------------------
+# Runtime-mutable settings
+# ---------------------------------------------------------------------------
+
+_tts_settings = {
+    "quality": os.getenv("TTS_QUALITY", "standard"),
+    "max_chunk_chars": int(os.getenv("TTS_MAX_CHUNK_CHARS", "800")),
+    "upsample_rate": int(os.getenv("TTS_UPSAMPLE_RATE", "44100")),
+}
+
+QUALITY_RATES = {
+    "standard": 24000,  # Qwen3-TTS native sample rate
+    "high": 44100,  # CD-quality upsampled via soxr
+}
+
+
+def get_tts_settings() -> dict:
+    """Return current TTS chunking/quality settings."""
+    quality = _tts_settings["quality"]
+    return {
+        "quality": quality,
+        "sample_rate": QUALITY_RATES.get(quality, 24000),
+        "max_chunk_chars": _tts_settings["max_chunk_chars"],
+        "available_qualities": list(QUALITY_RATES.keys()),
+    }
+
+
+def update_tts_settings(updates: dict) -> dict:
+    """Update TTS settings at runtime. Returns new settings."""
+    if "quality" in updates:
+        q = updates["quality"]
+        if q not in QUALITY_RATES:
+            raise ValueError(
+                f"Invalid quality '{q}'. Must be one of {list(QUALITY_RATES.keys())}"
+            )
+        _tts_settings["quality"] = q
+    if "max_chunk_chars" in updates:
+        val = int(updates["max_chunk_chars"])
+        if val < 100 or val > 5000:
+            raise ValueError("max_chunk_chars must be between 100 and 5000")
+        _tts_settings["max_chunk_chars"] = val
+    return get_tts_settings()
+
+
+# ---------------------------------------------------------------------------
+# Text splitting
+# ---------------------------------------------------------------------------
+
+def split_text_into_chunks(text: str, max_chars: int = 800) -> List[str]:
+    """Split text at sentence boundaries, with clause and word fallbacks.
+
+    Priority: sentence-end (.!?) > clause boundary (;:,—) > whitespace > hard cut.
+    """
+    text = text.strip()
+    if not text:
+        return []
+    if len(text) <= max_chars:
+        return [text]
+
+    chunks: List[str] = []
+    remaining = text
+
+    while remaining:
+        remaining = remaining.strip()
+        if not remaining:
+            break
+        if len(remaining) <= max_chars:
+            chunks.append(remaining)
+            break
+
+        segment = remaining[:max_chars]
+
+        # Try to find last sentence end
+        split_pos = _find_last_sentence_end(segment)
+        if split_pos == -1:
+            split_pos = _find_last_clause_boundary(segment)
+        if split_pos == -1:
+            split_pos = segment.rfind(" ")
+        if split_pos == -1:
+            split_pos = max_chars - 1
+
+        chunk = remaining[: split_pos + 1].strip()
+        if chunk:
+            chunks.append(chunk)
+        remaining = remaining[split_pos + 1 :]
+
+    return chunks
+
+
+def _find_last_sentence_end(text: str) -> int:
+    best = -1
+    for m in re.finditer(r"[.!?](?:\s|$)", text):
+        best = m.start()
+    return best
+
+
+def _find_last_clause_boundary(text: str) -> int:
+    best = -1
+    for m in re.finditer(r"[;:,\u2014](?:\s|$)", text):
+        best = m.start()
+    return best
+
+
+# ---------------------------------------------------------------------------
+# Audio concatenation
+# ---------------------------------------------------------------------------
+
+def concatenate_audio_chunks(
+    chunks: List[np.ndarray],
+    sr: int,
+    crossfade_ms: int = 50,
+) -> np.ndarray:
+    """Concatenate audio arrays with a short crossfade to avoid clicks."""
+    if not chunks:
+        return np.array([], dtype=np.float32)
+    if len(chunks) == 1:
+        return chunks[0]
+
+    crossfade_samples = int(sr * crossfade_ms / 1000)
+    result = chunks[0].copy()
+
+    for chunk in chunks[1:]:
+        if len(chunk) == 0:
+            continue
+        overlap = min(crossfade_samples, len(result), len(chunk))
+        if overlap > 0:
+            fade_out = np.linspace(1.0, 0.0, overlap, dtype=np.float32)
+            fade_in = np.linspace(0.0, 1.0, overlap, dtype=np.float32)
+            result[-overlap:] = result[-overlap:] * fade_out + chunk[:overlap] * fade_in
+            result = np.concatenate([result, chunk[overlap:]])
+        else:
+            result = np.concatenate([result, chunk])
+
+    return result
+
+
+# ---------------------------------------------------------------------------
+# Resampling
+# ---------------------------------------------------------------------------
+
+def resample_audio(audio: np.ndarray, src_rate: int, dst_rate: int) -> np.ndarray:
+    """Resample audio using soxr (VHQ), with linear-interp fallback."""
+    if src_rate == dst_rate:
+        return audio
+    try:
+        import soxr
+
+        return soxr.resample(audio, src_rate, dst_rate, quality="VHQ")
+    except ImportError:
+        logger.warning("soxr not installed; falling back to linear interpolation")
+        ratio = dst_rate / src_rate
+        new_len = int(len(audio) * ratio)
+        indices = np.linspace(0, len(audio) - 1, new_len)
+        return np.interp(indices, np.arange(len(audio)), audio).astype(np.float32)