diff --git a/backend/config.py b/backend/config.py index 76416ce..103d1b0 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,28 +1,150 @@ -""" -Vocalis Configuration Module - -Loads and provides access to configuration settings from environment variables -and the .env file. -""" +"""Vocalis configuration helpers.""" +import json +import logging import os +from pathlib import Path +from typing import Any, Dict, Optional, List + from dotenv import load_dotenv -from typing import Dict, Any # Load environment variables from .env file load_dotenv() +# Logger for configuration warnings +logger = logging.getLogger(__name__) + +# Repository paths +_REPO_ROOT = Path(__file__).resolve().parent.parent +MODEL_CACHE_DIR = Path( + os.getenv("MODEL_CACHE_DIR", str(_REPO_ROOT / "models")) +).expanduser().resolve() + + +def _load_json_env(var_name: str, default: Dict[str, Any]) -> Dict[str, Any]: + """Parse a JSON object stored in an environment variable.""" + + raw_value = os.getenv(var_name) + + if not raw_value: + return default + + try: + parsed = json.loads(raw_value) + except json.JSONDecodeError: + logger.warning("Unable to decode JSON for %s. Using default.", var_name) + return default + + if not isinstance(parsed, dict): + logger.warning( + "Environment variable %s must be a JSON object. 
Using default.", +                var_name, +            ) +            return default + +    return parsed + + +# --------------------------------------------------------------------------- +# Model presets +# --------------------------------------------------------------------------- + +AVAILABLE_STT_MODELS: List[Dict[str, Any]] = [ +    { +        "id": "kyutai/stt-1b-en_fr", +        "label": "Kyutai STT 1B (English/French)", +        "description": "Low-latency bilingual streaming model optimised for conversational speech.", +        "generation_config": {"max_new_tokens": 256}, +        "torch_dtype": "float16", +        "sample_rate": 16000, +    }, +    { +        "id": "kyutai/stt-2.6b-en", +        "label": "Kyutai STT 2.6B (English)", +        "description": "Highest accuracy Kyutai release with slightly higher latency footprint.", +        "generation_config": {"max_new_tokens": 256}, +        "torch_dtype": "float16", +        "sample_rate": 16000, +    }, +] + +AVAILABLE_TTS_MODELS: List[Dict[str, Any]] = [ +    { +        "id": "sesame/csm-1b", +        "label": "Sesame CSM 1B", +        "description": "Codec-style high fidelity model (requires gated access).", +        "provider": "huggingface-local", +        "output_format": "wav", +    }, +    { +        "id": "kyutai/tts-0.75b-en-public", +        "label": "Kyutai TTS 0.75B (English)", +        "description": "Entry level Kyutai TTS for lightweight hardware.", +        "provider": "huggingface-local", +        "output_format": "wav", +    }, +    { +        "id": "kyutai/tts-1.6b-en_fr", +        "label": "Kyutai TTS 1.6B (English/French)", +        "description": "Mid-sized Kyutai TTS with bilingual support.", +        "provider": "huggingface-local", +        "output_format": "wav", +    }, +] + + +def _find_model_option(options: List[Dict[str, Any]], model_id: str) -> Optional[Dict[str, Any]]: +    for option in options: +        if option.get("id") == model_id: +            return option +    return None + + +def get_available_models() -> Dict[str, Any]: +    """Return the configured model options.""" + +    return { +        "stt": AVAILABLE_STT_MODELS, +        "tts": AVAILABLE_TTS_MODELS, +    } + + +def get_stt_option(model_id: str) -> Optional[Dict[str, Any]]: +    return 
_find_model_option(AVAILABLE_STT_MODELS, model_id) + + +def get_tts_option(model_id: str) -> Optional[Dict[str, Any]]: + return _find_model_option(AVAILABLE_TTS_MODELS, model_id) + + # API Endpoints LLM_API_ENDPOINT = os.getenv("LLM_API_ENDPOINT", "http://127.0.0.1:1234/v1/chat/completions") TTS_API_ENDPOINT = os.getenv("TTS_API_ENDPOINT", "http://localhost:5005/v1/audio/speech") -# Whisper Model Configuration -WHISPER_MODEL = os.getenv("WHISPER_MODEL", "tiny.en") +# Speech-to-Text Configuration +STT_MODEL_ID = os.getenv( + "STT_MODEL_ID", + os.getenv("WHISPER_MODEL", "kyutai/stt-1b-en_fr"), +) +STT_DEVICE = os.getenv("STT_DEVICE") +STT_TORCH_DTYPE = os.getenv("STT_TORCH_DTYPE") +_DEFAULT_STT_OPTION = get_stt_option(os.getenv("WHISPER_MODEL", "kyutai/stt-1b-en_fr")) + +STT_GENERATION_CONFIG = _load_json_env( + "STT_GENERATION_CONFIG", + (_DEFAULT_STT_OPTION or {}).get("generation_config", {"max_new_tokens": 256}), +) # TTS Configuration -TTS_MODEL = os.getenv("TTS_MODEL", "tts-1") -TTS_VOICE = os.getenv("TTS_VOICE", "tara") -TTS_FORMAT = os.getenv("TTS_FORMAT", "wav") +_DEFAULT_TTS_OPTION = get_tts_option(os.getenv("TTS_MODEL", "sesame/csm-1b")) + +TTS_MODEL = os.getenv("TTS_MODEL", (_DEFAULT_TTS_OPTION or {}).get("id", "sesame/csm-1b")) +TTS_VOICE: Optional[str] = os.getenv("TTS_VOICE") +TTS_FORMAT = os.getenv("TTS_FORMAT", (_DEFAULT_TTS_OPTION or {}).get("output_format", "wav")) +TTS_PROVIDER = os.getenv("TTS_PROVIDER", (_DEFAULT_TTS_OPTION or {}).get("provider", "huggingface-local")) +TTS_API_KEY = os.getenv("TTS_API_KEY") +TTS_INFERENCE_PARAMS = _load_json_env("TTS_INFERENCE_PARAMS", {}) +TTS_EXTRA_HEADERS = _load_json_env("TTS_EXTRA_HEADERS", {}) # WebSocket Server Configuration WEBSOCKET_HOST = os.getenv("WEBSOCKET_HOST", "0.0.0.0") @@ -31,7 +153,7 @@ # Audio Processing VAD_THRESHOLD = float(os.getenv("VAD_THRESHOLD", 0.5)) VAD_BUFFER_SIZE = int(os.getenv("VAD_BUFFER_SIZE", 30)) -AUDIO_SAMPLE_RATE = int(os.getenv("AUDIO_SAMPLE_RATE", 48000)) 
+AUDIO_SAMPLE_RATE = int(os.getenv("AUDIO_SAMPLE_RATE", 16000)) def get_config() -> Dict[str, Any]: """ @@ -43,13 +165,64 @@ def get_config() -> Dict[str, Any]: return { "llm_api_endpoint": LLM_API_ENDPOINT, "tts_api_endpoint": TTS_API_ENDPOINT, - "whisper_model": WHISPER_MODEL, + "stt_model_id": STT_MODEL_ID, + "stt_device": STT_DEVICE, + "stt_torch_dtype": STT_TORCH_DTYPE, + "stt_generation_config": STT_GENERATION_CONFIG, "tts_model": TTS_MODEL, "tts_voice": TTS_VOICE, "tts_format": TTS_FORMAT, + "tts_provider": TTS_PROVIDER, + "tts_api_key": TTS_API_KEY, + "tts_inference_params": TTS_INFERENCE_PARAMS, + "tts_extra_headers": TTS_EXTRA_HEADERS, "websocket_host": WEBSOCKET_HOST, "websocket_port": WEBSOCKET_PORT, "vad_threshold": VAD_THRESHOLD, "vad_buffer_size": VAD_BUFFER_SIZE, "audio_sample_rate": AUDIO_SAMPLE_RATE, + "model_cache_dir": str(MODEL_CACHE_DIR), + "available_models": get_available_models(), } + + +def set_stt_model(model_id: str, generation_config: Optional[Dict[str, Any]] = None) -> None: + """Persist the currently active STT model in module globals.""" + + global STT_MODEL_ID, STT_GENERATION_CONFIG + + STT_MODEL_ID = model_id + + if generation_config is not None: + STT_GENERATION_CONFIG = dict(generation_config) + + +def set_tts_model( + model_id: str, + *, + provider: Optional[str] = None, + voice: Optional[str] = None, + output_format: Optional[str] = None, + inference_params: Optional[Dict[str, Any]] = None, + api_endpoint: Optional[str] = None, +) -> None: + """Persist the currently active TTS model configuration.""" + + global TTS_MODEL, TTS_PROVIDER, TTS_VOICE, TTS_FORMAT, TTS_INFERENCE_PARAMS, TTS_API_ENDPOINT + + TTS_MODEL = model_id + + if provider is not None: + TTS_PROVIDER = provider + + if voice is not None: + TTS_VOICE = voice + + if output_format is not None: + TTS_FORMAT = output_format + + if inference_params is not None: + TTS_INFERENCE_PARAMS = dict(inference_params) + + if api_endpoint is not None: + TTS_API_ENDPOINT = 
api_endpoint diff --git a/backend/main.py b/backend/main.py index 8574aa8..faa96d4 100644 --- a/backend/main.py +++ b/backend/main.py @@ -4,17 +4,20 @@ FastAPI application entry point. """ +import asyncio import logging import uvicorn -from fastapi import FastAPI, WebSocket, Depends, HTTPException +from fastapi import FastAPI, WebSocket, Depends, HTTPException, status from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager +from pydantic import BaseModel +from typing import Optional, Dict, Any # Import configuration from . import config # Import services -from .services.transcription import WhisperTranscriber +from .services.transcription import SpeechTranscriber from .services.llm import LLMClient from .services.tts import TTSClient from .services.vision import vision_service @@ -33,8 +36,17 @@ transcription_service = None llm_service = None tts_service = None +service_reload_lock = asyncio.Lock() # Vision service is a singleton already initialized in its module + +class ModelSelection(BaseModel): + stt_model_id: Optional[str] = None + tts_model_id: Optional[str] = None + stt_generation_config: Optional[Dict[str, Any]] = None + tts_inference_params: Optional[Dict[str, Any]] = None + tts_voice: Optional[str] = None + @asynccontextmanager async def lifespan(app: FastAPI): """ @@ -49,9 +61,13 @@ async def lifespan(app: FastAPI): global transcription_service, llm_service, tts_service # Initialize transcription service - transcription_service = WhisperTranscriber( - model_size=cfg["whisper_model"], - sample_rate=cfg["audio_sample_rate"] + transcription_service = SpeechTranscriber( + model_id=cfg["stt_model_id"], + device=cfg.get("stt_device"), + torch_dtype=cfg.get("stt_torch_dtype"), + sample_rate=cfg["audio_sample_rate"], + generation_config=cfg["stt_generation_config"], + cache_dir=str(config.MODEL_CACHE_DIR), ) # Initialize LLM service @@ -64,7 +80,11 @@ async def lifespan(app: FastAPI): api_endpoint=cfg["tts_api_endpoint"], 
model=cfg["tts_model"], voice=cfg["tts_voice"], - output_format=cfg["tts_format"] + output_format=cfg["tts_format"], + provider=cfg["tts_provider"], + api_key=cfg["tts_api_key"], + inference_params=cfg["tts_inference_params"], + extra_headers=cfg["tts_extra_headers"], ) # Initialize vision service (will download model if not cached) @@ -78,9 +98,9 @@ async def lifespan(app: FastAPI): # Cleanup on shutdown logger.info("Shutting down services...") - # No specific cleanup needed for these services, - # but we could add resource release code here if needed (maybe in a future release lex 31/03/25) - + if tts_service: + tts_service.close() + logger.info("Shutdown complete") # Create FastAPI application @@ -128,7 +148,7 @@ async def health_check(): "vision": vision_service.is_ready() }, "config": { - "whisper_model": config.WHISPER_MODEL, + "stt_model_id": config.STT_MODEL_ID, "tts_voice": config.TTS_VOICE, "websocket_port": config.WEBSOCKET_PORT } @@ -139,7 +159,7 @@ async def get_full_config(): """Get full configuration.""" if not all([transcription_service, llm_service, tts_service]) or not vision_service.is_ready(): raise HTTPException(status_code=503, detail="Services not initialized") - + return { "transcription": transcription_service.get_config(), "llm": llm_service.get_config(), @@ -147,6 +167,129 @@ async def get_full_config(): "system": config.get_config() } + +@app.get("/models") +async def list_models(): + """Return available model options and current selections.""" + + return { + "stt": { + "current": config.STT_MODEL_ID, + "options": config.AVAILABLE_STT_MODELS, + }, + "tts": { + "current": config.TTS_MODEL, + "options": config.AVAILABLE_TTS_MODELS, + }, + } + + +@app.post("/models/select", status_code=status.HTTP_202_ACCEPTED) +async def select_models(selection: ModelSelection): + """Reload STT and/or TTS services with the requested model identifiers.""" + + global transcription_service, tts_service + + updates: Dict[str, Dict[str, Any]] = {} + + async 
with service_reload_lock: + if selection.stt_model_id: + option = config.get_stt_option(selection.stt_model_id) + if option is None: + raise HTTPException(status_code=400, detail="Unknown STT model") + + raw_generation_config = selection.stt_generation_config or option.get( + "generation_config", + config.STT_GENERATION_CONFIG, + ) + generation_config = dict(raw_generation_config or {}) + + dtype_override = generation_config.pop("torch_dtype", None) + + device = option.get("device", config.STT_DEVICE) + torch_dtype = dtype_override or option.get("torch_dtype", config.STT_TORCH_DTYPE) + + current_config = transcription_service.get_config() if transcription_service else {} + is_same_model = current_config.get("model_id") == option["id"] + is_same_generation = generation_config == config.STT_GENERATION_CONFIG + + if is_same_model and is_same_generation: + logger.info("Requested STT model %s is already active; skipping reload", option["id"]) + else: + old_transcriber = transcription_service + transcription_service = SpeechTranscriber( + model_id=option["id"], + device=device, + torch_dtype=torch_dtype, + sample_rate=option.get("sample_rate", config.AUDIO_SAMPLE_RATE), + generation_config=generation_config, + cache_dir=str(config.MODEL_CACHE_DIR), + ) + + config.set_stt_model(option["id"], generation_config) + + updates["stt"] = transcription_service.get_config() + + if old_transcriber is not None: + del old_transcriber + + if selection.tts_model_id: + option = config.get_tts_option(selection.tts_model_id) + if option is None: + raise HTTPException(status_code=400, detail="Unknown TTS model") + + inference_params = selection.tts_inference_params or option.get( + "inference_params", + config.TTS_INFERENCE_PARAMS, + ) + + provider = option.get("provider", config.TTS_PROVIDER) + output_format = option.get("output_format", config.TTS_FORMAT) + voice = selection.tts_voice or option.get("voice", config.TTS_VOICE) + + api_endpoint = option.get("api_endpoint", 
config.TTS_API_ENDPOINT) + current_tts_config = tts_service.get_config() if tts_service else {} + is_same_tts_model = current_tts_config.get("model") == option["id"] + is_same_voice = voice == config.TTS_VOICE + is_same_provider = provider == config.TTS_PROVIDER + is_same_format = output_format == config.TTS_FORMAT + is_same_endpoint = api_endpoint == config.TTS_API_ENDPOINT + is_same_params = inference_params == config.TTS_INFERENCE_PARAMS + + if all([is_same_tts_model, is_same_voice, is_same_provider, is_same_format, is_same_endpoint, is_same_params]): + logger.info("Requested TTS model %s is already active; skipping reload", option["id"]) + else: + old_tts = tts_service + tts_service = TTSClient( + api_endpoint=api_endpoint, + model=option["id"], + voice=voice, + output_format=output_format, + provider=provider, + api_key=config.TTS_API_KEY, + inference_params=inference_params, + extra_headers=config.TTS_EXTRA_HEADERS, + ) + + config.set_tts_model( + option["id"], + provider=provider, + voice=voice, + output_format=output_format, + inference_params=inference_params, + api_endpoint=api_endpoint, + ) + + updates["tts"] = tts_service.get_config() + + if old_tts is not None: + old_tts.close() + + if not updates: + raise HTTPException(status_code=400, detail="No model changes requested") + + return {"status": "accepted", "updated": updates} + # WebSocket route @app.websocket("/ws") async def websocket_route(websocket: WebSocket): diff --git a/backend/requirements.txt b/backend/requirements.txt index 9452ce9..0207c3a 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -2,10 +2,11 @@ fastapi==0.109.2 uvicorn==0.27.1 python-dotenv==1.0.1 websockets==12.0 -numpy==1.26.4 +numpy faster-whisper==1.1.1 requests==2.31.0 python-multipart==0.0.9 torch>=2.0.1 ffmpeg-python==0.2.0 transformers>=4.31.0 +huggingface-hub>=0.20.0 diff --git a/backend/routes/websocket.py b/backend/routes/websocket.py index fd5e82e..87f1455 100644 --- a/backend/routes/websocket.py 
+++ b/backend/routes/websocket.py @@ -15,7 +15,7 @@ from pydantic import BaseModel from datetime import datetime -from ..services.transcription import WhisperTranscriber +from ..services.transcription import SpeechTranscriber from ..services.llm import LLMClient from ..services.tts import TTSClient from ..services.conversation_storage import ConversationStorage @@ -66,7 +66,7 @@ class WebSocketManager: def __init__( self, - transcriber: WhisperTranscriber, + transcriber: SpeechTranscriber, llm_client: LLMClient, tts_client: TTSClient ): @@ -1187,7 +1187,7 @@ async def _handle_vision_file_upload(self, websocket: WebSocket, image_base64: s async def websocket_endpoint( websocket: WebSocket, - transcriber: WhisperTranscriber, + transcriber: SpeechTranscriber, llm_client: LLMClient, tts_client: TTSClient ): diff --git a/backend/services/llm.py b/backend/services/llm.py index dd287ce..75cdd47 100644 --- a/backend/services/llm.py +++ b/backend/services/llm.py @@ -5,10 +5,12 @@ """ import json -import requests import logging +import time from typing import Dict, Any, List, Optional +import requests + # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -90,7 +92,7 @@ def get_response(self, user_input: str, system_prompt: Optional[str] = None, Dictionary containing the LLM response and metadata """ self.is_processing = True - start_time = logging.Formatter.converter() + start_time = time.perf_counter() try: # Prepare messages @@ -164,9 +166,8 @@ def get_response(self, user_input: str, system_prompt: Optional[str] = None, self.add_to_history("assistant", assistant_message) # Calculate processing time - end_time = logging.Formatter.converter() - processing_time = end_time[0] - start_time[0] - + processing_time = time.perf_counter() - start_time + logger.info(f"Received response from LLM API after {processing_time:.2f}s") return { diff --git a/backend/services/transcription.py b/backend/services/transcription.py index 
eb91c26..e293a66 100644 --- a/backend/services/transcription.py +++ b/backend/services/transcription.py @@ -1,197 +1,188 @@ -""" -Speech-to-Text Transcription Service +"""Speech-to-Text transcription service for local Kyutai models.""" -Uses Faster Whisper to transcribe speech audio. -""" +from __future__ import annotations -import numpy as np import logging -import io # For BytesIO -from typing import Dict, Any, List, Optional, Tuple -from faster_whisper import WhisperModel import time -import torch # For CUDA availability check +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor + -# Configure logging -logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -class WhisperTranscriber: - """ - Speech-to-Text service using Faster Whisper. - - This class handles transcription of speech audio segments. - """ - + +class SpeechTranscriber: + """Run automatic speech recognition with local Hugging Face models.""" + def __init__( self, - model_size: str = "base", - device: str = None, - compute_type: str = None, - beam_size: int = 2, - sample_rate: int = 44100 - ): - """ - Initialize the transcription service. 
- - Args: - model_size: Whisper model size (tiny.en, base.en, small.en, medium.en, large) - device: Device to run model on ('cpu' or 'cuda'), if None will auto-detect - compute_type: Model computation type (int8, int16, float16, float32), if None will select based on device - beam_size: Beam size for decoding - sample_rate: Audio sample rate in Hz - """ - self.model_size = model_size - - # Auto-detect device if not specified - if device is None: - self.device = "cuda" if torch.cuda.is_available() else "cpu" - else: - self.device = device - - # Select appropriate compute type based on device if not specified - if compute_type is None: - self.compute_type = "float16" if self.device == "cuda" else "int8" - else: - self.compute_type = compute_type - - self.beam_size = beam_size + model_id: str, + device: Optional[str] = None, + torch_dtype: Optional[str] = None, + sample_rate: int = 16_000, + generation_config: Optional[Dict[str, Any]] = None, + cache_dir: Optional[Union[str, Path]] = None, + ) -> None: + self.model_id = model_id self.sample_rate = sample_rate - - # Initialize model - self._initialize_model() - - # State tracking + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self.torch_dtype = self._resolve_dtype(torch_dtype) + self.generation_config = generation_config or {"max_new_tokens": 256} + self.cache_dir = Path(cache_dir).expanduser().resolve() if cache_dir else None + + self.model_source = self._resolve_model_source(self.model_id) + + logger.info( + "Loading speech model %s from %s on %s (dtype=%s)", + self.model_id, + self.model_source, + self.device, + self.torch_dtype, + ) + + load_kwargs: Dict[str, Any] = { + "torch_dtype": self.torch_dtype, + "low_cpu_mem_usage": True, + "trust_remote_code": True, + } + + processor_kwargs: Dict[str, Any] = {"trust_remote_code": True} + + if self.cache_dir: + load_kwargs["cache_dir"] = str(self.cache_dir) + processor_kwargs["cache_dir"] = str(self.cache_dir) + + self.processor = 
AutoProcessor.from_pretrained( + self.model_source, + **processor_kwargs, + ) + self.model = AutoModelForSpeechSeq2Seq.from_pretrained( + self.model_source, + **load_kwargs, + ) + self.model.to(self.device) + self.model.eval() + self.is_processing = False - - logger.info(f"Initialized Whisper Transcriber with model={model_size}, " - f"device={self.device}, compute_type={self.compute_type}") - - def _initialize_model(self): - """Initialize Whisper model.""" - try: - # Load the model - self.model = WhisperModel( - self.model_size, # Pass as positional argument, not keyword - device=self.device, - compute_type=self.compute_type - ) - logger.info(f"Successfully loaded Whisper model: {self.model_size}") - except Exception as e: - logger.error(f"Failed to load Whisper model: {e}") - raise - - def transcribe(self, audio: np.ndarray) -> Tuple[str, Dict[str, Any]]: - """ - Transcribe audio data to text. - - Args: - audio: Audio data as numpy array - - Returns: - Tuple[str, Dict[str, Any]]: - - Transcribed text - - Dictionary with additional information (confidence, language, etc.) 
- """ + +    def _resolve_model_source(self, model_id: str) -> str: +        if not self.cache_dir: +            return model_id + +        safe_name = model_id.replace("/", "__") +        candidate = self.cache_dir / safe_name + +        if candidate.exists(): +            return str(candidate) + +        logger.info( +            "Local cache for %s not found at %s; falling back to Hugging Face hub", +            model_id, +            candidate, +        ) +        return model_id + +    def _resolve_dtype(self, torch_dtype: Optional[str]) -> torch.dtype: +        if torch_dtype: +            try: +                return getattr(torch, torch_dtype) +            except AttributeError as exc:  # pragma: no cover - defensive path +                raise ValueError(f"Unsupported torch dtype: {torch_dtype}") from exc + +        if self.device.startswith("cuda"): +            return torch.float16 +        return torch.float32 + +    @staticmethod +    def _normalize_audio(audio: np.ndarray) -> np.ndarray: +        if audio.dtype == np.uint8: +            # NOTE(review): reconstructed span — raw bytes are reinterpreted as +            # little-endian PCM16 and scaled to [-1, 1]; confirm against the +            # frontend's audio encoding. +            audio = np.frombuffer(audio.tobytes(), dtype="<i2").astype(np.float32) / 32768.0 +        else: +            audio = audio.astype(np.float32) + +        max_val = float(np.max(np.abs(audio))) if audio.size else 0.0 +        if max_val > 0: +            audio /= max_val + +        return audio + +    def transcribe(self, audio: np.ndarray) -> tuple[str, Dict[str, Any]]: start_time = time.time() self.is_processing = True - + try: - # Handle WAV data (if audio is in uint8 format, it contains WAV headers) - if audio.dtype == np.uint8: - # First check the RIFF header to confirm this is WAV data - header = bytes(audio[:44]) - if header[:4] == b'RIFF' and header[8:12] == b'WAVE': - # Create a file-like object that Whisper can read from - audio_file = io.BytesIO(bytes(audio)) - # The transcribe method expects a file-like object with read method - audio = audio_file - else: - # Not a proper WAV header - logger.warning("Received audio data with incorrect WAV header") - # Attempt to process as raw data - audio = audio.astype(np.float32) / np.max(np.abs(audio)) if np.max(np.abs(audio)) > 0 else audio - else: - # Normalize audio if it's raw float data - audio = audio.astype(np.float32) / np.max(np.abs(audio)) if np.max(np.abs(audio)) > 0 else audio - - # Transcribe - segments, info = self.model.transcribe( - audio, - beam_size=self.beam_size, - 
language="en", # Force English language - vad_filter=False # Disable VAD filter since we handle it in the frontend + float_audio = self._normalize_audio(audio) + + inputs = self.processor( + audio=float_audio, + sampling_rate=self.sample_rate, + return_tensors="pt", ) - - # Collect all segment texts - text_segments = [segment.text for segment in segments] - full_text = " ".join(text_segments).strip() - - # Calculate processing time + + model_inputs = { + key: value.to(self.device) + for key, value in inputs.items() + if isinstance(value, torch.Tensor) + } + + if "input_features" in model_inputs: + input_features = model_inputs.pop("input_features") + elif "input_values" in model_inputs: + input_features = model_inputs.pop("input_values") + else: + raise ValueError("Processor did not return input features") + + generate_kwargs = dict(self.generation_config) + + with torch.no_grad(): + generated_ids = self.model.generate( + input_features, + **generate_kwargs, + ) + + transcript = self.processor.batch_decode( + generated_ids, + skip_special_tokens=True, + )[0] + processing_time = time.time() - start_time - logger.info(f"Transcription completed in {processing_time:.2f}s: {full_text[:50]}...") - + metadata = { - "confidence": getattr(info, "avg_logprob", 0), - "language": getattr(info, "language", "en"), "processing_time": processing_time, - "segments_count": len(text_segments) + "model_id": self.model_id, + "device": self.device, + "dtype": str(self.torch_dtype).split(".")[-1], + "num_tokens": int(generated_ids.shape[-1]) if generated_ids.ndim > 1 else 0, } - - return full_text, metadata - - except Exception as e: - logger.error(f"Transcription error: {e}") - return "", {"error": str(e)} - finally: - self.is_processing = False - - def transcribe_streaming(self, audio_generator): - """ - Stream transcription results from an audio generator. 
- - Args: - audio_generator: Generator yielding audio chunks - - Yields: - Partial transcription results as they become available - """ - self.is_processing = True - - try: - # Process the streaming transcription - segments = self.model.transcribe_with_vad( - audio_generator, - language="en" - ) - - # Yield each segment as it's transcribed - for segment in segments: - yield { - "text": segment.text, - "start": segment.start, - "end": segment.end, - "confidence": segment.avg_logprob - } - - except Exception as e: - logger.error(f"Streaming transcription error: {e}") - yield {"error": str(e)} + + return transcript.strip(), metadata + + except Exception as exc: + logger.error("Transcription error: %s", exc, exc_info=True) + return "", {"error": str(exc)} + finally: self.is_processing = False - + + def transcribe_streaming(self, audio_generator: Any): + raise NotImplementedError("Streaming transcription is not yet implemented for Kyutai models") + def get_config(self) -> Dict[str, Any]: - """ - Get the current configuration. - - Returns: - Dict containing the current configuration - """ return { - "model_size": self.model_size, + "model_id": self.model_id, + "model_path": self.model_source, "device": self.device, - "compute_type": self.compute_type, - "beam_size": self.beam_size, + "torch_dtype": str(self.torch_dtype).split(".")[-1], "sample_rate": self.sample_rate, - "is_processing": self.is_processing + "generation_config": self.generation_config, + "is_processing": self.is_processing, + "cache_dir": str(self.cache_dir) if self.cache_dir else None, } diff --git a/backend/services/tts.py b/backend/services/tts.py index 30abb66..4b71f44 100644 --- a/backend/services/tts.py +++ b/backend/services/tts.py @@ -1,52 +1,37 @@ -""" -Text-to-Speech Service - -Handles communication with the local TTS API endpoint. 
-""" +"""Text-to-Speech Service utilities.""" +import asyncio +import base64 import json -import requests import logging -import io import time -import base64 -import asyncio -from typing import Dict, Any, List, Optional, BinaryIO, Generator, AsyncGenerator +from typing import Any, Dict, Generator, Optional + +import requests # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class TTSClient: - """ - Client for communicating with a local TTS API. - - This class handles requests to a locally hosted TTS API that follows - the OpenAI API format for text-to-speech generation. - """ - + """Client for communicating with configurable TTS providers.""" + def __init__( self, api_endpoint: str = "http://localhost:5005/v1/audio/speech", model: str = "tts-1", - voice: str = "tara", + voice: Optional[str] = None, output_format: str = "wav", speed: float = 1.0, timeout: int = 60, - chunk_size: int = 4096 + chunk_size: int = 4096, + provider: str = "openai-compatible", + api_key: Optional[str] = None, + inference_params: Optional[Dict[str, Any]] = None, + extra_headers: Optional[Dict[str, str]] = None, ): - """ - Initialize the TTS client. - - Args: - api_endpoint: URL of the local TTS API - model: TTS model name to use - voice: Voice to use for synthesis - output_format: Output audio format (mp3, opus, aac, flac) - speed: Speech speed multiplier (0.25 to 4.0) - timeout: Request timeout in seconds - chunk_size: Size of audio chunks to stream in bytes - """ + """Initialize the TTS client.""" + self.api_endpoint = api_endpoint self.model = model self.voice = voice @@ -54,13 +39,184 @@ def __init__( self.speed = speed self.timeout = timeout self.chunk_size = chunk_size - + self.provider = provider + self.api_key = api_key + self.inference_params = inference_params or {} + self.extra_headers = extra_headers or {} + + # Persistent HTTP session for keep-alive reuse. 
+ self.session = requests.Session() + self.session.headers.update(self._build_default_headers()) + # State tracking self.is_processing = False self.last_processing_time = 0 - - logger.info(f"Initialized TTS Client with endpoint={api_endpoint}, " - f"model={model}, voice={voice}") + + logger.info( + "Initialized TTS Client", + extra={ + "endpoint": api_endpoint, + "model": model, + "voice": voice, + "provider": provider, + }, + ) + + def _build_default_headers(self) -> Dict[str, str]: + """Build the default headers for HTTP requests.""" + + headers: Dict[str, str] = {"Content-Type": "application/json"} + + if self.api_key: + headers.setdefault("Authorization", f"Bearer {self.api_key}") + + if self.provider.startswith("huggingface"): + # Hugging Face TTS endpoints commonly return binary audio. + headers.setdefault("Accept", "application/octet-stream") + + if self.extra_headers: + headers.update(self.extra_headers) + + return headers + + def _build_payload(self, text: str) -> Dict[str, Any]: + """Construct a provider-specific request payload.""" + + if self.provider.startswith("huggingface"): + parameters: Dict[str, Any] = dict(self.inference_params) + + if self.voice: + parameters.setdefault("voice", self.voice) + + if self.output_format: + parameters.setdefault("format", self.output_format) + + if self.speed and self.speed != 1.0: + parameters.setdefault("speed", self.speed) + + payload = { + "inputs": text, + "parameters": {k: v for k, v in parameters.items() if v is not None}, + } + + if self.model: + payload.setdefault("model", self.model) + + return payload + + payload = { + "model": self.model, + "input": text, + "response_format": self.output_format, + } + + if self.voice is not None: + payload["voice"] = self.voice + + if self.speed is not None: + payload["speed"] = self.speed + + return payload + + def _decode_audio_payload(self, payload: Any) -> Optional[bytes]: + """Extract audio bytes from a JSON payload.""" + + if payload is None: + return None + + if 
isinstance(payload, dict): + candidates = [ + payload.get("audio"), + payload.get("audio_base64"), + payload.get("generated_audio"), + payload.get("b64_audio"), + payload.get("data"), + ] + + for candidate in candidates: + audio_bytes = self._decode_audio_payload(candidate) + if audio_bytes: + return audio_bytes + + if isinstance(payload, (list, tuple)): + for item in payload: + audio_bytes = self._decode_audio_payload(item) + if audio_bytes: + return audio_bytes + + if isinstance(payload, (bytes, bytearray)): + return bytes(payload) + + if isinstance(payload, str): + try: + return base64.b64decode(payload) + except (ValueError, TypeError): + return None + + return None + + def _extract_audio_response(self, response: requests.Response) -> bytes: + """Return raw audio bytes from a HTTP response.""" + + content_type = response.headers.get("content-type", "") + + if "application/json" in content_type: + json_payload = response.json() + audio_bytes = self._decode_audio_payload(json_payload) + + if audio_bytes is None: + raise ValueError("TTS response did not include audio data") + + return audio_bytes + + return response.content + + def _iter_stream_chunks(self, response: requests.Response) -> Generator[bytes, None, None]: + """Yield audio chunks from a streaming HTTP response.""" + + content_type = response.headers.get("content-type", "") + + if "text/event-stream" in content_type: + for line in response.iter_lines(decode_unicode=False): + if not line: + continue + + if line.startswith(b"data:"): + data = line[len(b"data:") :].strip() + + if data in (b"[DONE]", b""): + continue + + try: + payload = json.loads(data.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + logger.debug("Skipping non-JSON SSE payload: %s", data) + continue + + audio_bytes = self._decode_audio_payload(payload) + + if audio_bytes: + yield audio_bytes + + return + + transfer_encoding = response.headers.get("transfer-encoding", "").lower() + + if "chunked" in transfer_encoding: 
+ for chunk in response.iter_content(chunk_size=self.chunk_size): + if chunk: + yield chunk + return + + # Fallback: non-streaming response, split for compatibility. + audio_data = self._extract_audio_response(response) + + total_chunks = (len(audio_data) + self.chunk_size - 1) // self.chunk_size + + for index in range(total_chunks): + start_idx = index * self.chunk_size + end_idx = min(start_idx + self.chunk_size, len(audio_data)) + yield audio_data[start_idx:end_idx] def text_to_speech(self, text: str) -> bytes: """ @@ -76,30 +232,27 @@ def text_to_speech(self, text: str) -> bytes: start_time = time.time() try: - # Prepare request payload - payload = { - "model": self.model, - "input": text, - "voice": self.voice, - "response_format": self.output_format, - "speed": self.speed - } - - logger.info(f"Sending TTS request with {len(text)} characters of text") - - # Send request to TTS API - response = requests.post( + payload = self._build_payload(text) + + logger.info( + "Sending TTS request", + extra={ + "chars": len(text), + "provider": self.provider, + "endpoint": self.api_endpoint, + }, + ) + + response = self.session.post( self.api_endpoint, json=payload, - timeout=self.timeout + timeout=self.timeout, ) - - # Check if request was successful + response.raise_for_status() - - # Get audio content - audio_data = response.content - + + audio_data = self._extract_audio_response(response) + # Calculate processing time self.last_processing_time = time.time() - start_time @@ -131,45 +284,29 @@ def stream_text_to_speech(self, text: str) -> Generator[bytes, None, None]: start_time = time.time() try: - # Prepare request payload - payload = { - "model": self.model, - "input": text, - "voice": self.voice, - "response_format": self.output_format, - "speed": self.speed - } - - logger.info(f"Sending streaming TTS request with {len(text)} characters of text") - - # Send request to TTS API - with requests.post( + payload = self._build_payload(text) + + logger.info( + "Sending 
streaming TTS request", + extra={ + "chars": len(text), + "provider": self.provider, + "endpoint": self.api_endpoint, + }, + ) + + with self.session.post( self.api_endpoint, json=payload, timeout=self.timeout, - stream=True + stream=True, ) as response: response.raise_for_status() - - # Check if streaming is supported by the API - is_chunked = response.headers.get('transfer-encoding', '') == 'chunked' - - if is_chunked: - # The API supports streaming - for chunk in response.iter_content(chunk_size=self.chunk_size): - if chunk: - yield chunk - else: - # The API doesn't support streaming, but we'll fake it by - # splitting the response into chunks - audio_data = response.content - total_chunks = (len(audio_data) + self.chunk_size - 1) // self.chunk_size - - for i in range(total_chunks): - start_idx = i * self.chunk_size - end_idx = min(start_idx + self.chunk_size, len(audio_data)) - yield audio_data[start_idx:end_idx] - + + for chunk in self._iter_stream_chunks(response): + if chunk: + yield chunk + # Calculate processing time self.last_processing_time = time.time() - start_time logger.info(f"Completed TTS streaming after {self.last_processing_time:.2f}s") @@ -207,6 +344,11 @@ async def async_text_to_speech(self, text: str) -> bytes: raise finally: self.is_processing = False + + def close(self) -> None: + """Close the underlying HTTP session.""" + + self.session.close() def get_config(self) -> Dict[str, Any]: """ diff --git a/frontend/src/components/PreferencesModal.tsx b/frontend/src/components/PreferencesModal.tsx index 112a2e1..4568808 100644 --- a/frontend/src/components/PreferencesModal.tsx +++ b/frontend/src/components/PreferencesModal.tsx @@ -1,5 +1,5 @@ -import React, { useState, useEffect } from 'react'; -import { User, Sparkles, Eye } from 'lucide-react'; +import React, { useState, useEffect, useCallback } from 'react'; +import { User, Sparkles, Eye, SlidersHorizontal } from 'lucide-react'; import websocketService, { MessageType } from 
'../services/websocket';
 
 interface PreferencesModalProps {
@@ -7,13 +7,64 @@
   onClose: () => void;
 }
 
+interface ModelOption {
+  id: string;
+  label: string;
+  description?: string;
+}
+
 const PreferencesModal: React.FC<PreferencesModalProps> = ({ isOpen, onClose }) => {
   const [systemPrompt, setSystemPrompt] = useState('');
   const [userName, setUserName] = useState('');
   const [isSaving, setIsSaving] = useState(false);
   const [saveError, setSaveError] = useState<string | null>(null);
-  const [activeTab, setActiveTab] = useState<'profile' | 'system'>('profile');
+  const [activeTab, setActiveTab] = useState<'profile' | 'system' | 'models'>('profile');
   const [isVisionEnabled, setIsVisionEnabled] = useState(false);
+  const [sttOptions, setSttOptions] = useState<ModelOption[]>([]);
+  const [ttsOptions, setTtsOptions] = useState<ModelOption[]>([]);
+  const [currentSttModel, setCurrentSttModel] = useState('');
+  const [currentTtsModel, setCurrentTtsModel] = useState('');
+  const [selectedSttModel, setSelectedSttModel] = useState('');
+  const [selectedTtsModel, setSelectedTtsModel] = useState('');
+  const [modelsError, setModelsError] = useState<string | null>(null);
+  const [modelsLoading, setModelsLoading] = useState(false);
+  const apiBaseUrl = (import.meta.env.VITE_API_BASE_URL as string | undefined) ?? 'http://localhost:8000';
+
+  const fetchModelOptions = useCallback(async () => {
+    setModelsLoading(true);
+    setModelsError(null);
+
+    try {
+      const response = await fetch(`${apiBaseUrl}/models`);
+
+      if (!response.ok) {
+        throw new Error(`Failed to load models: ${response.status}`);
+      }
+
+      const data = await response.json();
+      const stt = data?.stt ?? {};
+      const tts = data?.tts ?? {};
+
+      const sttList: ModelOption[] = Array.isArray(stt.options) ? stt.options : [];
+      const ttsList: ModelOption[] = Array.isArray(tts.options) ? tts.options : [];
+
+      setSttOptions(sttList);
+      setTtsOptions(ttsList);
+
+      const sttCurrent = typeof stt.current === 'string' ? 
stt.current : ''; + const ttsCurrent = typeof tts.current === 'string' ? tts.current : ''; + + setCurrentSttModel(sttCurrent); + setCurrentTtsModel(ttsCurrent); + setSelectedSttModel(sttCurrent); + setSelectedTtsModel(ttsCurrent); + } catch (error) { + console.error('Error loading model options', error); + setModelsError('Failed to load available models. Ensure the backend is running.'); + } finally { + setModelsLoading(false); + } + }, [apiBaseUrl]); useEffect(() => { if (isOpen) { @@ -42,25 +93,31 @@ const PreferencesModal: React.FC = ({ isOpen, onClose }) // Listen for responses websocketService.addEventListener(MessageType.SYSTEM_PROMPT, handleSystemPrompt); websocketService.addEventListener(MessageType.USER_PROFILE, handleUserProfile); - websocketService.addEventListener(MessageType.VISION_SETTINGS as any, handleVisionSettings); + websocketService.addEventListener(MessageType.VISION_SETTINGS, handleVisionSettings); // Request data websocketService.getSystemPrompt(); websocketService.getUserProfile(); websocketService.getVisionSettings(); - + + fetchModelOptions(); + console.log('Requested preferences data'); return () => { websocketService.removeEventListener(MessageType.SYSTEM_PROMPT, handleSystemPrompt); websocketService.removeEventListener(MessageType.USER_PROFILE, handleUserProfile); - websocketService.removeEventListener(MessageType.VISION_SETTINGS as any, handleVisionSettings); + websocketService.removeEventListener(MessageType.VISION_SETTINGS, handleVisionSettings); }; } - }, [isOpen]); + }, [isOpen, fetchModelOptions]); // Listen for update confirmations useEffect(() => { + if (activeTab === 'models') { + return; + } + let updateCount = 0; const expectedUpdateCount = 3; // Always expect 3 updates: system prompt, user profile, and vision let success = true; @@ -115,25 +172,86 @@ const PreferencesModal: React.FC = ({ isOpen, onClose }) websocketService.addEventListener(MessageType.SYSTEM_PROMPT_UPDATED, handlePromptUpdated); 
websocketService.addEventListener(MessageType.USER_PROFILE_UPDATED, handleProfileUpdated);
-    websocketService.addEventListener(MessageType.VISION_SETTINGS_UPDATED as any, handleVisionSettingsUpdated);
+    websocketService.addEventListener(MessageType.VISION_SETTINGS_UPDATED, handleVisionSettingsUpdated);
     
     return () => {
       websocketService.removeEventListener(MessageType.SYSTEM_PROMPT_UPDATED, handlePromptUpdated);
       websocketService.removeEventListener(MessageType.USER_PROFILE_UPDATED, handleProfileUpdated);
-      websocketService.removeEventListener(MessageType.VISION_SETTINGS_UPDATED as any, handleVisionSettingsUpdated);
+      websocketService.removeEventListener(MessageType.VISION_SETTINGS_UPDATED, handleVisionSettingsUpdated);
     };
   }, [onClose, activeTab]);
   
-  const handleSave = () => {
+  const handleSave = async () => {
+    if (activeTab === 'models') {
+      setSaveError(null);
+      setModelsError(null);
+
+      const payload: Record<string, string> = {};
+
+      if (selectedSttModel && selectedSttModel !== currentSttModel) {
+        payload.stt_model_id = selectedSttModel;
+      }
+
+      if (selectedTtsModel && selectedTtsModel !== currentTtsModel) {
+        payload.tts_model_id = selectedTtsModel;
+      }
+
+      if (Object.keys(payload).length === 0) {
+        onClose();
+        return;
+      }
+
+      setIsSaving(true);
+
+      try {
+        const response = await fetch(`${apiBaseUrl}/models/select`, {
+          method: 'POST',
+          headers: {
+            'Content-Type': 'application/json'
+          },
+          body: JSON.stringify(payload)
+        });
+
+        if (!response.ok) {
+          const message = await response.text();
+          throw new Error(message || 'Failed to update models');
+        }
+
+        await response.json();
+
+        if (payload.stt_model_id) {
+          const sttId = payload.stt_model_id as string;
+          setCurrentSttModel(sttId);
+          setSelectedSttModel(sttId);
+        }
+
+        if (payload.tts_model_id) {
+          const ttsId = payload.tts_model_id as string;
+          setCurrentTtsModel(ttsId);
+          setSelectedTtsModel(ttsId);
+        }
+
+        await fetchModelOptions();
+        onClose();
+      } catch (error) {
+        console.error('Failed to update models', error);
+        
setModelsError('Failed to update models. Check backend logs for more details.'); + } finally { + setIsSaving(false); + } + + return; + } + // Check if system prompt is empty when in system tab if (activeTab === 'system' && !systemPrompt.trim()) { setSaveError('System prompt cannot be empty'); return; } - + setIsSaving(true); setSaveError(null); - + // Always update all settings websocketService.updateSystemPrompt(systemPrompt); websocketService.updateUserProfile(userName); @@ -157,41 +275,6 @@ const PreferencesModal: React.FC = ({ isOpen, onClose }) ); }; - // Tab rendering helpers - const renderVisionTab = () => ( -
-
- -
-
setIsVisionEnabled(!isVisionEnabled)} - > - -
- - {isVisionEnabled ? 'Enabled' : 'Disabled'} - -
-

- When enabled, Vocalis will use computer vision to analyze images and provide visual context to your conversations. -

-
-

- Coming Soon: Vision capabilities will allow Vocalis to see and describe images, - analyze documents, interpret charts, and provide visual assistance during your conversations. -

-
-
-
- ); - const renderProfileTab = () => (
@@ -234,6 +317,60 @@ const PreferencesModal: React.FC = ({ isOpen, onClose })

); + + const renderModelsTab = () => { + const sttDetail = sttOptions.find((option) => option.id === selectedSttModel); + const ttsDetail = ttsOptions.find((option) => option.id === selectedTtsModel); + + return ( +
+
+ + +

+ {sttDetail?.description ?? 'Select the speech recognition model used for live transcription.'} +

+
+ +
+ + +

+ {ttsDetail?.description ?? 'Choose the voice synthesis engine used for playback.'} +

+
+ +
+

+ Model changes reload the backend pipelines and may take a few moments. Sessions in progress should be paused before + switching for best results. +

+
+
+ ); + }; // Handle animation state const [isVisible, setIsVisible] = useState(false); @@ -280,19 +417,40 @@ const PreferencesModal: React.FC = ({ isOpen, onClose }) System Prompt +
- + {/* Content */}
- {activeTab === 'profile' - ? renderProfileTab() - : renderSystemTab()} - - {saveError && ( + {activeTab === 'profile' && renderProfileTab()} + {activeTab === 'system' && renderSystemTab()} + {activeTab === 'models' && renderModelsTab()} + + {activeTab !== 'models' && saveError && (
{saveError}
)} + + {activeTab === 'models' && modelsError && ( +
+ {modelsError} +
+ )} + + {activeTab === 'models' && modelsLoading && ( +
Loading available models…
+ )}
{/* Vision Settings Section */} diff --git a/frontend/src/services/websocket.ts b/frontend/src/services/websocket.ts index 9a6dab7..a96bb7e 100644 --- a/frontend/src/services/websocket.ts +++ b/frontend/src/services/websocket.ts @@ -62,33 +62,7 @@ export interface Session { } // Event types -type WebSocketEventType = - | 'open' - | 'close' - | 'error' - | 'audio' - | 'transcription' - | 'llm_response' - | 'tts_start' - | 'tts_chunk' - | 'tts_end' - | 'status' - | 'ping' - | 'pong' - | 'error' - | 'system_prompt' - | 'system_prompt_updated' - | 'user_profile' - | 'user_profile_updated' - | 'save_session_result' - | 'load_session_result' - | 'list_sessions_result' - | 'delete_session_result' - | 'vision_settings' - | 'vision_settings_updated' - | 'vision_file_upload_result' - | 'vision_processing' - | 'vision_ready'; +type WebSocketEventType = MessageType | 'open' | 'close' | 'error'; // WebSocket state export enum ConnectionState { diff --git a/run.sh b/run.sh index f9a5ac7..70598ca 100644 --- a/run.sh +++ b/run.sh @@ -3,6 +3,17 @@ echo "=== Starting Vocalis ===" +if [ -f "./env/bin/activate" ]; then + echo "=== Ensuring required speech models are available ===" + source ./env/bin/activate + python scripts/download_models.py || echo "[WARN] Some models failed to download. Check your Hugging Face credentials." + if command -v deactivate >/dev/null 2>&1; then + deactivate + fi +else + echo "[WARN] Python virtual environment not found. Run ./setup.sh before starting services." 
+fi
+
 # Determine which terminal command to use based on OS and available commands
 terminal_cmd=""
 if [ "$(uname)" == "Darwin" ]; then
diff --git a/scripts/download_models.py b/scripts/download_models.py
new file mode 100755
index 0000000..81f7bb1
--- /dev/null
+++ b/scripts/download_models.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+"""Download and cache Vocalis speech models from Hugging Face."""
+
+from __future__ import annotations
+
+import argparse
+import logging
+import sys
+from pathlib import Path
+from typing import Iterable, Dict, Any, Set
+
+from huggingface_hub import snapshot_download
+
+REPO_ROOT = Path(__file__).resolve().parent.parent
+BACKEND_PATH = REPO_ROOT / "backend"
+
+if str(REPO_ROOT) not in sys.path:
+    sys.path.insert(0, str(REPO_ROOT))
+
+from backend import config  # noqa: E402
+
+LOGGER = logging.getLogger("download_models")
+logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
+
+
+def _safe_directory_name(model_id: str) -> str:
+    return model_id.replace("/", "__")
+
+
+def _collect_models(include_stt: bool, include_tts: bool) -> Iterable[Dict[str, Any]]:
+    if include_stt:
+        yield from config.AVAILABLE_STT_MODELS
+    if include_tts:
+        yield from config.AVAILABLE_TTS_MODELS
+
+
+def download_model(model_id: str, destination: Path) -> None:
+    LOGGER.info("Ensuring model %s is available at %s", model_id, destination)
+
+    destination.mkdir(parents=True, exist_ok=True)
+
+    snapshot_download(
+        repo_id=model_id,
+        local_dir=str(destination),
+        local_dir_use_symlinks=False,
+        resume_download=True,
+    )
+
+
+def main() -> int:
+    parser = argparse.ArgumentParser(description="Download Kyutai/Sesame models for local inference")
+    parser.add_argument(
+        "--cache-dir",
+        default=str(config.MODEL_CACHE_DIR),
+        help="Directory where models should be stored (default: %(default)s)",
+    )
+    parser.add_argument(
+        "--stt",
+        action="store_true",
+        help="Download only STT models",
+    )
+    parser.add_argument(
+        "--tts",
+        
action="store_true", + help="Download only TTS models", + ) + parser.add_argument( + "--model", + action="append", + default=[], + help="Additional model repository IDs to download", + ) + + args = parser.parse_args() + + cache_dir = Path(args.cache_dir).expanduser().resolve() + cache_dir.mkdir(parents=True, exist_ok=True) + + include_stt = args.stt or not (args.stt or args.tts) + include_tts = args.tts or not (args.stt or args.tts) + + models: Set[str] = set(args.model or []) + + for option in _collect_models(include_stt, include_tts): + model_id = option.get("id") + if model_id: + models.add(model_id) + + if not models: + LOGGER.warning("No models requested. Nothing to do.") + return 0 + + errors = False + + for model_id in sorted(models): + target_dir = cache_dir / _safe_directory_name(model_id) + try: + download_model(model_id, target_dir) + except Exception as exc: # pragma: no cover - network/auth errors + errors = True + LOGGER.error("Failed to download %s: %s", model_id, exc) + + if errors: + LOGGER.error("One or more models failed to download. Check authentication and retry.") + return 1 + + LOGGER.info("All requested models are available in %s", cache_dir) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())