204 changes: 165 additions & 39 deletions server.py
@@ -8,6 +8,7 @@
import logging
import logging.handlers # For RotatingFileHandler
import shutil
import struct
import time
import uuid
import yaml # For loading presets
@@ -251,6 +252,35 @@ async def get_main_script():
# These functions support smart audio chunk concatenation with crossfading


def _create_wav_header(
sample_rate: int, channels: int = 1, bits_per_sample: int = 16
) -> bytes:
    """Build a 44-byte PCM WAV header for a stream of unknown total length.

    The RIFF and data chunk sizes are pinned to the 32-bit maximum, which
    players treat as "read until EOF"; this is the usual trick when the
    total byte count is not known at header-write time.
    """
byte_rate = sample_rate * channels * (bits_per_sample // 8)
block_align = channels * (bits_per_sample // 8)
max_32bit_int = 0xFFFFFFFF
riff_chunk_size = max_32bit_int
data_size = riff_chunk_size - 36
header = b"RIFF"
header += struct.pack("<I", riff_chunk_size)
header += b"WAVEfmt "
header += struct.pack("<I", 16)
header += struct.pack("<H", 1)
header += struct.pack("<H", channels)
header += struct.pack("<I", sample_rate)
header += struct.pack("<I", byte_rate)
header += struct.pack("<H", block_align)
header += struct.pack("<H", bits_per_sample)
header += b"data"
header += struct.pack("<I", data_size)
return header


def _float32_to_pcm16(audio: np.ndarray) -> bytes:
    """Clip float32 samples to [-1.0, 1.0] and pack them as little-endian PCM16."""
    audio = np.clip(audio, -1.0, 1.0)
audio_int16 = (audio * 32767).astype(np.int16)
return audio_int16.tobytes()
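

# Illustrative sketch, not part of the diff: the two helpers above are all it
# takes to assemble a playable WAV by hand (one header, then raw PCM16
# frames). This function is example-only and is never called by the server.
def _example_build_test_wav(path: str = "stream_smoke_test.wav") -> None:
    """Write a one-second 440 Hz mono tone using the streaming helpers."""
    sr = 24000
    t = np.arange(sr, dtype=np.float32) / sr
    tone = 0.5 * np.sin(2.0 * np.pi * 440.0 * t)
    with open(path, "wb") as f:
        f.write(_create_wav_header(sr))  # sizes pinned, players read to EOF
        f.write(_float32_to_pcm16(tone))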


def _generate_equal_power_curves(n_samples: int):
"""
Generate equal-power crossfade curves using cos²/sin² functions.
@@ -500,7 +530,9 @@ async def save_settings_endpoint(request: Request):
)
message = "Settings saved successfully."
if restart_needed:
message += " A server restart may be required for some changes to take full effect."
message += (
" A server restart may be required for some changes to take effect."
)
return UpdateStatusResponse(message=message, restart_needed=restart_needed)
else:
logger.error(
@@ -917,7 +949,14 @@ async def custom_tts_endpoint(
status_code=400, detail="Text processing resulted in no usable chunks."
)

    base_seed = request.seed if request.seed is not None else get_gen_default_seed()

    for i, chunk in enumerate(text_chunks):
        # Derive a per-chunk seed (base_seed + i): reproducible for a fixed
        # base seed, while still varying from chunk to chunk.
        current_chunk_seed = base_seed
        if base_seed is not None and base_seed >= 0:
            current_chunk_seed = base_seed + i

logger.info(f"Synthesizing chunk {i+1}/{len(text_chunks)}...")
try:
chunk_audio_tensor, chunk_sr_from_engine = engine.synthesize(
@@ -942,9 +981,7 @@
if request.cfg_weight is not None
else get_gen_default_cfg_weight()
),
seed=(
request.seed if request.seed is not None else get_gen_default_seed()
),
seed=current_chunk_seed,
language=(
request.language
if request.language is not None
@@ -981,9 +1018,6 @@ async def custom_tts_endpoint(
)
perf_monitor.record(f"Speed factor applied to chunk {i+1}")

# ### MODIFICATION ###
# All other processing is REMOVED from the loop.
# We will process the final concatenated audio clip.
processed_audio_np = current_processed_audio_tensor.cpu().numpy().squeeze()
all_audio_segments_np.append(processed_audio_np)

@@ -1007,21 +1041,18 @@
)
try:
# ### SMART AUDIO STITCHING ###
# Local constants - adjust these values to tune stitching behavior
SENTENCE_PAUSE_MS = 200 # Desired audible silence between sentences
CROSSFADE_MS = 20 # Crossfade duration for smart mode (10-50ms recommended)
SAFETY_FADE_MS = 3 # Minimal edge fade for fallback mode (2-5ms)
ENABLE_DC_REMOVAL = False # Set True if you hear low-frequency thumps
DC_HIGHPASS_HZ = 15 # High-pass cutoff for DC removal
PEAK_NORMALIZE_THRESHOLD = 0.99 # Normalize if peak exceeds this
PEAK_NORMALIZE_TARGET = 0.95 # Target peak after normalization

# Read smart stitching toggle from config (defaults to True)
SENTENCE_PAUSE_MS = 200
CROSSFADE_MS = 20
SAFETY_FADE_MS = 3
ENABLE_DC_REMOVAL = False
DC_HIGHPASS_HZ = 15
PEAK_NORMALIZE_THRESHOLD = 0.99
PEAK_NORMALIZE_TARGET = 0.95

enable_smart_stitching = config_manager.get_bool(
"audio_processing.enable_crossfade", True
)
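        # Illustrative config sketch, not part of the diff; the YAML layout
        # below is an assumption, inferred from the config_manager keys read
        # in this function:
        #
        #   audio_processing:
        #     enable_crossfade: true         # this toggle
        #     enable_silence_trimming: false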

# --- Sample rate validation ---
if not engine_output_sample_rate or engine_output_sample_rate <= 0:
logger.error(
f"Invalid sample rate: {engine_output_sample_rate}, "
@@ -1034,22 +1065,16 @@
)

elif len(all_audio_segments_np) == 1:
# Single chunk - no stitching needed
final_audio_np = all_audio_segments_np[0]
logger.info("Single audio chunk - no stitching required")

elif enable_smart_stitching:
# --- Smart mode: true crossfading with silence insertion ---
fade_samples = int(CROSSFADE_MS / 1000 * engine_output_sample_rate)

# Calculate silence buffer with compensation for crossfade overlap
# Each crossfade removes fade_samples from silence (one at each end)
desired_silence_samples = int(
SENTENCE_PAUSE_MS / 1000 * engine_output_sample_rate
)
silence_buffer_samples = desired_silence_samples + (fade_samples * 2)
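            # Worked example (assuming 24 kHz engine output; not part of the
            # diff): CROSSFADE_MS = 20 gives fade_samples = 480 and
            # SENTENCE_PAUSE_MS = 200 gives desired_silence_samples = 4800,
            # so the buffer is 4800 + 2 * 480 = 5760 samples. The two
            # crossfades below each consume 480 of those, leaving exactly
            # the desired 200 ms audible pause.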

# Preprocess chunks: convert to float32 and optionally remove DC offset
chunks = []
for chunk in all_audio_segments_np:
processed = chunk.astype(np.float32, copy=True)
@@ -1059,18 +1084,11 @@
)
chunks.append(processed)

# Start with first chunk
result = chunks[0]

# Stitch remaining chunks with crossfaded silence gaps
for i in range(1, len(chunks)):
# Create silence buffer (oversized to compensate for crossfade overlap)
silence = np.zeros(silence_buffer_samples, dtype=np.float32)

# Crossfade: current result → silence (speech fades into silence)
result = _crossfade_with_overlap(result, silence, fade_samples)

# Crossfade: result → next chunk (silence fades into speech)
result = _crossfade_with_overlap(result, chunks[i], fade_samples)

final_audio_np = result
@@ -1080,7 +1098,6 @@
)

else:
# --- Fallback mode: minimal safety edge fades, no silence ---
fade_samples = int(SAFETY_FADE_MS / 1000 * engine_output_sample_rate)
num_chunks = len(all_audio_segments_np)

@@ -1092,8 +1109,8 @@
processed = _apply_edge_fades(
chunk,
fade_samples,
fade_in=(not is_first), # No fade-in on first chunk
fade_out=(not is_last), # No fade-out on last chunk
fade_in=(not is_first),
fade_out=(not is_last),
)
processed_chunks.append(processed)

@@ -1103,10 +1120,8 @@
f"{SAFETY_FADE_MS}ms linear fades"
)

# --- Ensure float32 dtype for all code paths ---
final_audio_np = final_audio_np.astype(np.float32, copy=False)

# --- Normalize to prevent clipping ---
peak_amplitude = np.abs(final_audio_np).max()
if peak_amplitude > PEAK_NORMALIZE_THRESHOLD:
final_audio_np = final_audio_np * (PEAK_NORMALIZE_TARGET / peak_amplitude)
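            # Worked example, not part of the diff: a stitched peak of 1.20
            # is scaled by 0.95 / 1.20 ≈ 0.79, putting the new peak exactly
            # at PEAK_NORMALIZE_TARGET.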
@@ -1116,7 +1131,6 @@

perf_monitor.record("Audio chunks stitched")

# --- Global Audio Post-Processing (applied to complete stitched audio) ---
if config_manager.get_bool("audio_processing.enable_silence_trimming", False):
final_audio_np = utils.trim_lead_trail_silence(
final_audio_np, engine_output_sample_rate
@@ -1140,7 +1154,6 @@
)
perf_monitor.record("Global unvoiced removal applied")

# --- Warn about potentially conflicting settings ---
if enable_smart_stitching and config_manager.get_bool(
"audio_processing.enable_silence_trimming", False
):
@@ -1224,6 +1237,119 @@ async def custom_tts_endpoint(
)


async def stream_tts_generator(
    request: CustomTTSRequest,
    audio_prompt_path: Optional[str] = None,
    target_sample_rate: int = 24000,
):
    """Yield a WAV header, then PCM16 audio chunk by chunk as it is synthesized.

    Consecutive chunks are joined with a short equal-power crossfade: the
    last CROSSFADE_MS of each chunk is withheld until the next chunk
    arrives, blended with that chunk's head, and only then emitted. The
    header advertises target_sample_rate and no resampling happens here,
    so the engine output rate is assumed to match.
    """
    chunk_size = request.chunk_size if request.chunk_size else 120
    raw_chunks = utils.chunk_text_by_sentences(request.text, chunk_size)
    text_chunks = [c for c in raw_chunks if c.strip()]

    CROSSFADE_MS = 20

    yield _create_wav_header(target_sample_rate)

    buffered_tail: Optional[np.ndarray] = None
    engine_sr: Optional[int] = None
    fade_samples: int = 0

    base_seed = request.seed if request.seed is not None else get_gen_default_seed()

    import asyncio

    loop = asyncio.get_running_loop()

    for i, chunk_text in enumerate(text_chunks):
        try:
            # Same per-chunk seeding scheme as custom_tts_endpoint above.
            current_chunk_seed = base_seed
            if base_seed is not None and base_seed >= 0:
                current_chunk_seed = base_seed + i

chunk_audio_tensor, sr = await loop.run_in_executor(
None,
lambda: engine.synthesize(
text=chunk_text,
audio_prompt_path=audio_prompt_path,
temperature=request.temperature or get_gen_default_temperature(),
exaggeration=request.exaggeration or get_gen_default_exaggeration(),
cfg_weight=request.cfg_weight or get_gen_default_cfg_weight(),
seed=current_chunk_seed,
language=request.language or get_gen_default_language(),
),
)

if chunk_audio_tensor is None:
continue

speed = request.speed_factor or get_gen_default_speed_factor()
if speed != 1.0:
chunk_audio_tensor, _ = utils.apply_speed_factor(
chunk_audio_tensor, sr, speed
)

audio_np = chunk_audio_tensor.cpu().numpy().squeeze().astype(np.float32)

if engine_sr is None:
engine_sr = sr
fade_samples = int(CROSSFADE_MS / 1000 * engine_sr)

            if len(audio_np) < fade_samples * 2:
                # Too short to crossfade on both ends: flush any held tail
                # and emit the chunk verbatim.
                if buffered_tail is not None:
                    yield _float32_to_pcm16(buffered_tail)
                    buffered_tail = None
                yield _float32_to_pcm16(audio_np)
                continue

            if buffered_tail is not None:
                # Blend the previous chunk's held-back tail with this head.
                new_head = audio_np[:fade_samples]
                fade_out, fade_in = _generate_equal_power_curves(fade_samples)
                crossfaded_region = (buffered_tail * fade_out) + (new_head * fade_in)
                yield _float32_to_pcm16(crossfaded_region)
                body = audio_np[fade_samples:-fade_samples]
            else:
                body = audio_np[:-fade_samples]

            yield _float32_to_pcm16(body)
            # Hold this chunk's tail until the next chunk arrives.
            buffered_tail = audio_np[-fade_samples:]

except Exception as e:
logger.error(f"Streaming error on chunk {i}: {e}")
break

if buffered_tail is not None:
yield _float32_to_pcm16(buffered_tail)


@app.post("/tts/stream", tags=["TTS Generation"])
async def tts_stream_endpoint(request: CustomTTSRequest):
logger.info(f"Stream request received for: {request.text[:30]}...")
audio_prompt_path_str = None
if request.voice_mode == "predefined" and request.predefined_voice_id:
p = (
get_predefined_voices_path(ensure_absolute=True)
/ request.predefined_voice_id
)
if p.exists():
audio_prompt_path_str = str(p)
elif request.voice_mode == "clone" and request.reference_audio_filename:
p = (
get_reference_audio_path(ensure_absolute=True)
/ request.reference_audio_filename
)
if p.exists():
audio_prompt_path_str = str(p)

if not engine.MODEL_LOADED:
raise HTTPException(status_code=503, detail="Model not loaded")

target_sr = get_audio_sample_rate()
return StreamingResponse(
stream_tts_generator(request, audio_prompt_path_str, target_sr),
media_type="audio/wav",
)
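

# Illustrative client sketch, not part of the diff: one way a caller might
# consume /tts/stream. The `requests` dependency, the localhost:8004 address,
# and the voice id are all assumptions; payload fields mirror CustomTTSRequest
# as used above. Example-only, never called by the server.
def _example_stream_to_file(out_path: str = "out.wav") -> None:
    import requests  # example-only dependency

    payload = {
        "text": "Hello there. This is a streaming test.",
        "voice_mode": "predefined",
        "predefined_voice_id": "example_voice.wav",  # hypothetical voice id
    }
    with requests.post(
        "http://localhost:8004/tts/stream", json=payload, stream=True
    ) as resp:
        resp.raise_for_status()
        with open(out_path, "wb") as f:
            for block in resp.iter_content(chunk_size=8192):
                f.write(block)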


@app.post("/v1/audio/speech", tags=["OpenAI Compatible"])
async def openai_speech_endpoint(request: OpenAISpeechRequest):
# Determine the audio prompt path based on the voice parameter
@@ -1235,7 +1361,7 @@ async def openai_speech_endpoint(request: OpenAISpeechRequest):
if voice_path_predefined.is_file():
audio_prompt_path = voice_path_predefined
elif voice_path_reference.is_file():
audio_prompt_path = voice_path_reference
audio_prompt_path = reference_audio_path
else:
raise HTTPException(
status_code=404, detail=f"Voice file '{request.voice}' not found."