9 changes: 8 additions & 1 deletion app/src/components/Generation/FloatingGenerateBox.tsx
@@ -13,7 +13,7 @@ import {
} from '@/components/ui/select';
import { Textarea } from '@/components/ui/textarea';
import { useToast } from '@/components/ui/use-toast';
-import { LANGUAGE_OPTIONS } from '@/lib/constants/languages';
+import { LANGUAGE_OPTIONS, type LanguageCode } from '@/lib/constants/languages';
import { useGenerationForm } from '@/lib/hooks/useGenerationForm';
import { useProfile, useProfiles } from '@/lib/hooks/useProfiles';
import { useAddStoryItem, useStory } from '@/lib/hooks/useStories';
@@ -112,6 +112,13 @@ export function FloatingGenerateBox({
}
}, [selectedProfileId, profiles, setSelectedProfileId]);

+// Sync generation form language with selected profile's language
+useEffect(() => {
+  if (selectedProfile?.language) {
+    form.setValue('language', selectedProfile.language as LanguageCode);
+  }
+}, [selectedProfile, form]);

// Auto-resize textarea based on content (only when expanded)
useEffect(() => {
if (!isExpanded) {
23 changes: 15 additions & 8 deletions backend/backends/mlx_backend.py
@@ -14,6 +14,12 @@
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager

+LANGUAGE_CODE_TO_NAME = {
+    "zh": "chinese", "en": "english", "ja": "japanese", "ko": "korean",
+    "de": "german", "fr": "french", "ru": "russian", "pt": "portuguese",
+    "es": "spanish", "it": "italian",
+}


class MLXTTSBackend:
"""MLX-based TTS backend using mlx-audio."""
@@ -316,24 +322,25 @@ def _generate_sync():
# MLX generate() returns a generator yielding GenerationResult objects
audio_chunks = []
sample_rate = 24000

+lang = LANGUAGE_CODE_TO_NAME.get(language, "auto")

# Set seed if provided (MLX uses numpy random)
if seed is not None:
import mlx.core as mx
np.random.seed(seed)
mx.random.seed(seed)

# Extract voice prompt info
ref_audio = voice_prompt.get("ref_audio") or voice_prompt.get("ref_audio_path")
ref_text = voice_prompt.get("ref_text", "")

# Validate that the audio file exists
if ref_audio and not Path(ref_audio).exists():
print(f"Warning: Audio file not found: {ref_audio}")
print("This may be due to a cached voice prompt referencing a deleted temp file.")
print("Regenerating without voice prompt.")
ref_audio = None

# Check if model supports voice cloning via generate method
# MLX API may support ref_audio parameter directly
try:
@@ -344,23 +351,23 @@ def _generate_sync():
sig = inspect.signature(self.model.generate)
if "ref_audio" in sig.parameters:
# Generate with voice cloning
-for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text):
+for result in self.model.generate(text, ref_audio=ref_audio, ref_text=ref_text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
else:
# Fallback: generate without voice cloning
-for result in self.model.generate(text):
+for result in self.model.generate(text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
else:
# No voice prompt, generate normally
-for result in self.model.generate(text):
+for result in self.model.generate(text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate
except Exception as e:
# If voice cloning fails, try without it
print(f"Warning: Voice cloning failed, generating without voice prompt: {e}")
-for result in self.model.generate(text):
+for result in self.model.generate(text, lang_code=lang):
audio_chunks.append(np.array(result.audio))
sample_rate = result.sample_rate

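Note: the new LANGUAGE_CODE_TO_NAME table (duplicated in pytorch_backend.py below) falls back to "auto" whenever the incoming code is missing or unmapped, which leaves language detection to the model. A minimal, runnable sketch of that fallback; the assertions and the unmapped "xx" code are illustrative, not part of this diff:

LANGUAGE_CODE_TO_NAME = {
    "zh": "chinese", "en": "english", "ja": "japanese", "ko": "korean",
    "de": "german", "fr": "french", "ru": "russian", "pt": "portuguese",
    "es": "spanish", "it": "italian",
}

# Mapped codes resolve to the full names the models expect;
# anything else degrades gracefully to "auto".
assert LANGUAGE_CODE_TO_NAME.get("en", "auto") == "english"
assert LANGUAGE_CODE_TO_NAME.get("xx", "auto") == "auto"   # unmapped code
assert LANGUAGE_CODE_TO_NAME.get(None, "auto") == "auto"   # language not set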
7 changes: 7 additions & 0 deletions backend/backends/pytorch_backend.py
@@ -15,6 +15,12 @@
from ..utils.hf_progress import HFProgressTracker, create_hf_progress_callback
from ..utils.tasks import get_task_manager

+LANGUAGE_CODE_TO_NAME = {
+    "zh": "chinese", "en": "english", "ja": "japanese", "ko": "korean",
+    "de": "german", "fr": "french", "ru": "russian", "pt": "portuguese",
+    "es": "spanish", "it": "italian",
+}


class PyTorchTTSBackend:
"""PyTorch-based TTS backend using Qwen3-TTS."""
@@ -335,6 +341,7 @@ def _generate_sync():
wavs, sample_rate = self.model.generate_voice_clone(
text=text,
voice_clone_prompt=voice_prompt,
+language=LANGUAGE_CODE_TO_NAME.get(language, "auto"),
instruct=instruct,
)
return wavs[0], sample_rate
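Note: on the PyTorch side the mapped name is passed straight through as a keyword argument. A sketch of the call shape with a stand-in model class; the stub, the "fr" input, and the prompt dict are illustrative, it reuses LANGUAGE_CODE_TO_NAME from the sketch above, and it assumes generate_voice_clone keeps the signature shown in this diff:

# Stand-in for the loaded Qwen3-TTS model, mimicking the assumed signature.
class _StubModel:
    def generate_voice_clone(self, text, voice_clone_prompt, language, instruct=None):
        print(f"generating {language!r}: {text}")
        return [b"wav-bytes"], 24000

model = _StubModel()
language = "fr"  # two-letter code coming from the frontend profile
wavs, sample_rate = model.generate_voice_clone(
    text="Bonjour tout le monde.",
    voice_clone_prompt={"ref_audio": None, "ref_text": ""},
    language=LANGUAGE_CODE_TO_NAME.get(language, "auto"),  # resolves to "french"
    instruct=None,
)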