Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions backend/backends/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def _generate_sync():
class PyTorchSTTBackend:
"""PyTorch-based STT backend using Whisper."""

def __init__(self, model_size: str = "base"):
def __init__(self, model_size: str = "turbo"):
self.model = None
self.processor = None
self.model_size = model_size
Expand Down Expand Up @@ -379,7 +379,10 @@ def _is_model_cached(self, model_size: str) -> bool:
"""
try:
from huggingface_hub import constants as hf_constants
model_name = f"openai/whisper-{model_size}"
model_size_to_hf = {
"turbo": "openai/whisper-large-v3-turbo",
}
model_name = model_size_to_hf.get(model_size, f"openai/whisper-{model_size}")
repo_cache = Path(hf_constants.HF_HUB_CACHE) / ("models--" + model_name.replace("/", "--"))

if not repo_cache.exists():
Expand Down Expand Up @@ -457,7 +460,10 @@ def _load_model_sync(self, model_size: str):
# Import transformers
from transformers import WhisperProcessor, WhisperForConditionalGeneration

model_name = f"openai/whisper-{model_size}"
model_size_to_hf = {
"turbo": "openai/whisper-large-v3-turbo",
}
model_name = model_size_to_hf.get(model_size, f"openai/whisper-{model_size}")
print(f"[DEBUG] Model name: {model_name}")

print(f"Loading Whisper model {model_size} on {self.device}...")
Expand Down Expand Up @@ -546,21 +552,20 @@ def _transcribe_sync():
)
inputs = inputs.to(self.device)

# Set language if provided
forced_decoder_ids = None
# Generate transcription
# If language is provided, force it; otherwise let Whisper auto-detect
generate_kwargs = {}
if language:
# Support all languages from frontend: en, zh, ja, ko, de, fr, ru, pt, es, it
# Whisper supports these and many more
forced_decoder_ids = self.processor.get_decoder_prompt_ids(
language=language,
task="transcribe",
)
generate_kwargs["forced_decoder_ids"] = forced_decoder_ids

# Generate transcription
with torch.no_grad():
predicted_ids = self.model.generate(
inputs["input_features"],
forced_decoder_ids=forced_decoder_ids,
**generate_kwargs,
)

# Decode
Expand Down
34 changes: 30 additions & 4 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,20 @@
from pathlib import Path
import uuid
import asyncio
import signal
import os

# Set HSA_OVERRIDE_GFX_VERSION for AMD GPUs that aren't officially listed in ROCm
# (e.g., RX 6600 is gfx1032 which maps to gfx1030 target)
# This must be set BEFORE any torch.cuda calls
if not os.environ.get("HSA_OVERRIDE_GFX_VERSION"):
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"

# Suppress noisy MIOpen workspace warnings on AMD GPUs
if not os.environ.get("MIOPEN_LOG_LEVEL"):
os.environ["MIOPEN_LOG_LEVEL"] = "4"

import signal

from . import database, models, profiles, history, tts, transcribe, config, export_import, channels, stories, __version__
from .database import get_db, Generation as DBGeneration, VoiceProfile as DBVoiceProfile
from .utils.progress import get_progress_manager
Expand Down Expand Up @@ -816,9 +827,12 @@ async def transcribe_audio(
# Transcribe
whisper_model = transcribe.get_whisper_model()

# Check if Whisper model is downloaded (uses default size "base")
# Check if Whisper model is downloaded
model_size = whisper_model.model_size
model_name = f"openai/whisper-{model_size}"
model_size_to_hf = {
"turbo": "openai/whisper-large-v3-turbo",
}
model_name = model_size_to_hf.get(model_size, f"openai/whisper-{model_size}")

# Check if model is cached
from huggingface_hub import constants as hf_constants
Expand Down Expand Up @@ -1248,6 +1262,13 @@ def check_whisper_loaded(model_size: str):
"model_size": "large",
"check_loaded": lambda: check_whisper_loaded("large"),
},
{
"model_name": "whisper-turbo",
"display_name": "Whisper Turbo",
"hf_repo_id": "openai/whisper-large-v3-turbo",
"model_size": "turbo",
"check_loaded": lambda: check_whisper_loaded("turbo"),
},
]

# Build a mapping of model_name -> hf_repo_id so we can check if shared repos are downloading
Expand Down Expand Up @@ -1642,7 +1663,12 @@ def _get_gpu_status() -> str:
"""Get GPU availability status."""
backend_type = get_backend_type()
if torch.cuda.is_available():
return f"CUDA ({torch.cuda.get_device_name(0)})"
device_name = torch.cuda.get_device_name(0)
# Check if this is ROCm (AMD) or CUDA (NVIDIA)
is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None
if is_rocm:
return f"ROCm ({device_name})"
return f"CUDA ({device_name})"
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
return "MPS (Apple Silicon)"
elif backend_type == "mlx":
Expand Down
5 changes: 3 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
"landing"
],
"scripts": {
"dev": "bun run setup:dev && cd tauri && bun run tauri dev",
"dev": "cd tauri && bun run dev",
"dev:tauri": "bun run setup:dev && cd tauri && bun run tauri dev",
"dev:web": "cd web && bun run dev",
"dev:landing": "cd landing && bun run dev",
"dev:server": "uvicorn backend.main:app --reload --port 17493",
Expand Down Expand Up @@ -41,4 +42,4 @@
"bun": ">=1.0.0"
},
"packageManager": "bun@1.3.8"
}
}
Loading