From f4b692b7f369fce836460b81785bfc580a1a3bad Mon Sep 17 00:00:00 2001 From: mdabucse Date: Sat, 7 Feb 2026 19:32:39 +0530 Subject: [PATCH 1/3] refactor: extract core library with event callbacks (#34) - Add CorePipeline class with zero Rich dependencies - Implement event/callback system (EventType, PipelineEvent, PipelineResult) - Move modules to core/ (config, llm, youtube, pipeline) - Add backward compatibility shims for old import paths - Fix mypy type errors in pipeline.py - Remove hardcoded API key from .env.example --- .env.example | 2 +- src/yt_study/cli.py | 8 +- src/yt_study/config.py | 217 +-------- src/yt_study/core/__init__.py | 40 ++ src/yt_study/core/config.py | 215 +++++++++ src/yt_study/core/llm/__init__.py | 1 + src/yt_study/core/llm/generator.py | 355 ++++++++++++++ src/yt_study/core/llm/providers.py | 152 ++++++ src/yt_study/core/orchestrator.py | 534 ++++++++++++++++++++ src/yt_study/core/pipeline.py | 427 ++++++++++++++++ src/yt_study/core/prompts/__init__.py | 1 + src/yt_study/core/prompts/chapter_notes.py | 55 +++ src/yt_study/core/prompts/study_notes.py | 106 ++++ src/yt_study/core/setup_wizard.py | 420 ++++++++++++++++ src/yt_study/core/youtube/__init__.py | 1 + src/yt_study/core/youtube/metadata.py | 163 +++++++ src/yt_study/core/youtube/parser.py | 117 +++++ src/yt_study/core/youtube/playlist.py | 97 ++++ src/yt_study/core/youtube/transcript.py | 297 ++++++++++++ src/yt_study/llm/__init__.py | 4 +- src/yt_study/llm/generator.py | 356 +------------- src/yt_study/llm/providers.py | 153 +----- src/yt_study/pipeline/__init__.py | 7 +- src/yt_study/pipeline/orchestrator.py | 536 +-------------------- src/yt_study/setup_wizard.py | 422 +--------------- src/yt_study/youtube/__init__.py | 4 +- src/yt_study/youtube/metadata.py | 164 +------ src/yt_study/youtube/parser.py | 118 +---- src/yt_study/youtube/playlist.py | 98 +--- src/yt_study/youtube/transcript.py | 298 +----------- uv.lock | 4 +- 31 files changed, 3018 insertions(+), 2354 
deletions(-) create mode 100644 src/yt_study/core/__init__.py create mode 100644 src/yt_study/core/config.py create mode 100644 src/yt_study/core/llm/__init__.py create mode 100644 src/yt_study/core/llm/generator.py create mode 100644 src/yt_study/core/llm/providers.py create mode 100644 src/yt_study/core/orchestrator.py create mode 100644 src/yt_study/core/pipeline.py create mode 100644 src/yt_study/core/prompts/__init__.py create mode 100644 src/yt_study/core/prompts/chapter_notes.py create mode 100644 src/yt_study/core/prompts/study_notes.py create mode 100644 src/yt_study/core/setup_wizard.py create mode 100644 src/yt_study/core/youtube/__init__.py create mode 100644 src/yt_study/core/youtube/metadata.py create mode 100644 src/yt_study/core/youtube/parser.py create mode 100644 src/yt_study/core/youtube/playlist.py create mode 100644 src/yt_study/core/youtube/transcript.py diff --git a/.env.example b/.env.example index f82d014..3f855fb 100644 --- a/.env.example +++ b/.env.example @@ -2,7 +2,7 @@ # Copy this to .env and fill in your API keys # Required: At least one LLM API key (Gemini is the default provider) -GEMINI_API_KEY=your_gemini_key_here +GEMINI_API_KEY="your_gemini_api_key_here" # Optional: Other LLM Providers # OPENAI_API_KEY=your_openai_key_here diff --git a/src/yt_study/cli.py b/src/yt_study/cli.py index 622e8fc..b612fc5 100644 --- a/src/yt_study/cli.py +++ b/src/yt_study/cli.py @@ -78,7 +78,7 @@ def ensure_setup() -> None: "\n[yellow]⚠ No configuration found. 
Running setup wizard...[/yellow]\n" ) try: - from .setup_wizard import run_setup_wizard + from .core.setup_wizard import run_setup_wizard run_setup_wizard(force=False) except ImportError as e: @@ -177,8 +177,8 @@ def process( try: # Lazy import for faster CLI startup - from .config import config - from .pipeline.orchestrator import PipelineOrchestrator + from .core.config import config + from .core.orchestrator import PipelineOrchestrator # Use config values as defaults, allow CLI overrides selected_model = model or config.default_model @@ -283,7 +283,7 @@ def setup( Runs a wizard to generate the [bold]~/.yt-study/config.env[/bold] file. """ try: - from .setup_wizard import run_setup_wizard + from .core.setup_wizard import run_setup_wizard run_setup_wizard(force=force) except ImportError as e: diff --git a/src/yt_study/config.py b/src/yt_study/config.py index 315c91a..bc6a272 100644 --- a/src/yt_study/config.py +++ b/src/yt_study/config.py @@ -1,215 +1,4 @@ -"""Configuration management for yt-study.""" +"""Backward compatibility - config moved to core.config.""" -import logging -import os -from dataclasses import dataclass, field -from pathlib import Path - - -logger = logging.getLogger(__name__) - - -@dataclass -class Config: - """ - Global configuration for the application. - - Manages loading settings from environment variables and config files. 
- """ - - # LLM Configuration - default_model: str = "gemini/gemini-2.0-flash" - gemini_api_key: str | None = None - openai_api_key: str | None = None - anthropic_api_key: str | None = None - groq_api_key: str | None = None - xai_api_key: str | None = None - mistral_api_key: str | None = None - - # LLM Generation Parameters - temperature: float = 0.7 - max_tokens: int | None = None - - # Chunking Configuration - chunk_size: int = 4000 # tokens - chunk_overlap: int = 200 # tokens - - # Concurrency Configuration - max_concurrent_videos: int = 5 - - # Output Configuration - default_output_dir: Path = Path("./output") - - # Transcript Configuration - default_languages: list[str] = field(default_factory=lambda: ["en"]) - - # Security: Allowed keys for environment injection - ALLOWED_KEYS: set[str] = field( - default_factory=lambda: { - "GEMINI_API_KEY", - "OPENAI_API_KEY", - "ANTHROPIC_API_KEY", - "GROQ_API_KEY", - "XAI_API_KEY", - "MISTRAL_API_KEY", - "DEFAULT_MODEL", - "OUTPUT_DIR", - "MAX_CONCURRENT_VIDEOS", - "TEMPERATURE", - "MAX_TOKENS", - } - ) - - def __post_init__(self) -> None: - """Load configuration from user config file and environment variables.""" - # First, try to load from user config file - self._load_from_user_config() - - # Then load/override with environment variables - self.gemini_api_key = os.getenv("GEMINI_API_KEY") or self.gemini_api_key - self.openai_api_key = os.getenv("OPENAI_API_KEY") or self.openai_api_key - self.anthropic_api_key = ( - os.getenv("ANTHROPIC_API_KEY") or self.anthropic_api_key - ) - self.groq_api_key = os.getenv("GROQ_API_KEY") or self.groq_api_key - self.xai_api_key = os.getenv("XAI_API_KEY") or self.xai_api_key - self.mistral_api_key = os.getenv("MISTRAL_API_KEY") or self.mistral_api_key - - # Load default model and output dir from config - env_model = os.getenv("DEFAULT_MODEL") - if env_model: - self.default_model = env_model - - env_output = os.getenv("OUTPUT_DIR") - if env_output: - self.default_output_dir = 
Path(env_output) - - env_concurrency = os.getenv("MAX_CONCURRENT_VIDEOS") - if env_concurrency: - try: - self.max_concurrent_videos = int(env_concurrency) - except ValueError: - logger.warning( - f"Invalid MAX_CONCURRENT_VIDEOS value: {env_concurrency}. " - f"Using default {self.max_concurrent_videos}" - ) - - env_temperature = os.getenv("TEMPERATURE") - if env_temperature: - try: - temp_value = float(env_temperature) - if not (0 <= temp_value <= 1): - logger.warning( - f"TEMPERATURE out of range [0, 1]: {env_temperature}. " - f"Using default 0.7" - ) - - else: - self.temperature = temp_value - except ValueError: - logger.warning( - f"Invalid TEMPERATURE value: {env_temperature}. Using default 0.7" - ) - - env_max_tokens = os.getenv("MAX_TOKENS") - if env_max_tokens: - try: - # self.max_tokens = int(env_max_tokens) - max_tokens_value = int(env_max_tokens) - if max_tokens_value < 1: - logger.warning( - f"MAX_TOKENS must be >= 1: {env_max_tokens}. " - f"Setting to None (default)" - ) - - else: - self.max_tokens = max_tokens_value - - except ValueError: - logger.warning( - f"Invalid MAX_TOKENS value: {env_max_tokens}. 
" - f"Setting to None (default)" - ) - - self._sync_env_vars() - - def _load_from_user_config(self) -> None: - """Load configuration from user's config file.""" - config_path = Path.home() / ".yt-study" / "config.env" - - if not config_path.exists(): - return - - try: - with config_path.open(encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line or line.startswith("#"): - continue - - if "=" in line: - key, value = line.split("=", 1) - key = key.strip() - value = value.strip() - - # Remove quotes if present - if (value.startswith('"') and value.endswith('"')) or ( - value.startswith("'") and value.endswith("'") - ): - value = value[1:-1] - - if key in self.ALLOWED_KEYS: - # Pre-populate env for consistency - if key not in os.environ: - os.environ[key] = value - else: - logger.warning(f"Ignoring unauthorized config key: {key}") - - except Exception as e: - logger.warning(f"Failed to load config file: {e}") - pass - - def _sync_env_vars(self) -> None: - """Sync class attributes back to os.environ for libraries that expect them.""" - if self.gemini_api_key: - os.environ["GEMINI_API_KEY"] = self.gemini_api_key - if self.openai_api_key: - os.environ["OPENAI_API_KEY"] = self.openai_api_key - if self.anthropic_api_key: - os.environ["ANTHROPIC_API_KEY"] = self.anthropic_api_key - if self.groq_api_key: - os.environ["GROQ_API_KEY"] = self.groq_api_key - if self.xai_api_key: - os.environ["XAI_API_KEY"] = self.xai_api_key - if self.mistral_api_key: - os.environ["MISTRAL_API_KEY"] = self.mistral_api_key - - def get_api_key_name_for_model(self, model: str) -> str | None: - """Get the environment variable name for the API key required by a model.""" - model_lower = model.lower() - - if "gemini" in model_lower or "vertex" in model_lower: - return "GEMINI_API_KEY" - elif "gpt" in model_lower or "openai" in model_lower: - return "OPENAI_API_KEY" - elif "claude" in model_lower or "anthropic" in model_lower: - return "ANTHROPIC_API_KEY" - elif "groq" in 
model_lower: - return "GROQ_API_KEY" - elif "grok" in model_lower or "xai" in model_lower: - return "XAI_API_KEY" - elif "mistral" in model_lower: - return "MISTRAL_API_KEY" - - return None - - def get_api_key_for_model(self, model: str) -> str | None: - """Get the appropriate API key value for a given model.""" - var_name = self.get_api_key_name_for_model(model) - if var_name: - return os.environ.get(var_name) - return None - - -# Global config instance -config = Config() +# Re-export everything from the new location +from yt_study.core.config import * # noqa: F401, F403 diff --git a/src/yt_study/core/__init__.py b/src/yt_study/core/__init__.py new file mode 100644 index 0000000..ecd9adb --- /dev/null +++ b/src/yt_study/core/__init__.py @@ -0,0 +1,40 @@ +"""Core pipeline module - zero UI dependencies. + +This module provides the core pipeline functionality that can be used +by any frontend (CLI, web API, GUI, etc.) without UI dependencies. + +Usage: + >>> from yt_study.core import CorePipeline, EventType + >>> + >>> pipeline = CorePipeline(model="gemini-1.5-flash") + >>> + >>> def on_progress(event): + ... if event.event_type == EventType.VIDEO_SUCCESS: + ... 
print(f"Done: {event.title}") + >>> + >>> result = await pipeline.run(["VIDEO_ID"], on_event=on_progress) +""" + +# Keep backward compatibility with old PipelineOrchestrator +from .orchestrator import PipelineOrchestrator +from .pipeline import ( + CorePipeline, + EventType, + PipelineEvent, + PipelineResult, + run_pipeline, + sanitize_filename, +) + + +__all__ = [ + # New core API + "CorePipeline", + "EventType", + "PipelineEvent", + "PipelineResult", + "run_pipeline", + "sanitize_filename", + # Legacy (deprecated, for backward compatibility) + "PipelineOrchestrator", +] diff --git a/src/yt_study/core/config.py b/src/yt_study/core/config.py new file mode 100644 index 0000000..315c91a --- /dev/null +++ b/src/yt_study/core/config.py @@ -0,0 +1,215 @@ +"""Configuration management for yt-study.""" + +import logging +import os +from dataclasses import dataclass, field +from pathlib import Path + + +logger = logging.getLogger(__name__) + + +@dataclass +class Config: + """ + Global configuration for the application. + + Manages loading settings from environment variables and config files. 
+ """ + + # LLM Configuration + default_model: str = "gemini/gemini-2.0-flash" + gemini_api_key: str | None = None + openai_api_key: str | None = None + anthropic_api_key: str | None = None + groq_api_key: str | None = None + xai_api_key: str | None = None + mistral_api_key: str | None = None + + # LLM Generation Parameters + temperature: float = 0.7 + max_tokens: int | None = None + + # Chunking Configuration + chunk_size: int = 4000 # tokens + chunk_overlap: int = 200 # tokens + + # Concurrency Configuration + max_concurrent_videos: int = 5 + + # Output Configuration + default_output_dir: Path = Path("./output") + + # Transcript Configuration + default_languages: list[str] = field(default_factory=lambda: ["en"]) + + # Security: Allowed keys for environment injection + ALLOWED_KEYS: set[str] = field( + default_factory=lambda: { + "GEMINI_API_KEY", + "OPENAI_API_KEY", + "ANTHROPIC_API_KEY", + "GROQ_API_KEY", + "XAI_API_KEY", + "MISTRAL_API_KEY", + "DEFAULT_MODEL", + "OUTPUT_DIR", + "MAX_CONCURRENT_VIDEOS", + "TEMPERATURE", + "MAX_TOKENS", + } + ) + + def __post_init__(self) -> None: + """Load configuration from user config file and environment variables.""" + # First, try to load from user config file + self._load_from_user_config() + + # Then load/override with environment variables + self.gemini_api_key = os.getenv("GEMINI_API_KEY") or self.gemini_api_key + self.openai_api_key = os.getenv("OPENAI_API_KEY") or self.openai_api_key + self.anthropic_api_key = ( + os.getenv("ANTHROPIC_API_KEY") or self.anthropic_api_key + ) + self.groq_api_key = os.getenv("GROQ_API_KEY") or self.groq_api_key + self.xai_api_key = os.getenv("XAI_API_KEY") or self.xai_api_key + self.mistral_api_key = os.getenv("MISTRAL_API_KEY") or self.mistral_api_key + + # Load default model and output dir from config + env_model = os.getenv("DEFAULT_MODEL") + if env_model: + self.default_model = env_model + + env_output = os.getenv("OUTPUT_DIR") + if env_output: + self.default_output_dir = 
Path(env_output) + + env_concurrency = os.getenv("MAX_CONCURRENT_VIDEOS") + if env_concurrency: + try: + self.max_concurrent_videos = int(env_concurrency) + except ValueError: + logger.warning( + f"Invalid MAX_CONCURRENT_VIDEOS value: {env_concurrency}. " + f"Using default {self.max_concurrent_videos}" + ) + + env_temperature = os.getenv("TEMPERATURE") + if env_temperature: + try: + temp_value = float(env_temperature) + if not (0 <= temp_value <= 1): + logger.warning( + f"TEMPERATURE out of range [0, 1]: {env_temperature}. " + f"Using default 0.7" + ) + + else: + self.temperature = temp_value + except ValueError: + logger.warning( + f"Invalid TEMPERATURE value: {env_temperature}. Using default 0.7" + ) + + env_max_tokens = os.getenv("MAX_TOKENS") + if env_max_tokens: + try: + # self.max_tokens = int(env_max_tokens) + max_tokens_value = int(env_max_tokens) + if max_tokens_value < 1: + logger.warning( + f"MAX_TOKENS must be >= 1: {env_max_tokens}. " + f"Setting to None (default)" + ) + + else: + self.max_tokens = max_tokens_value + + except ValueError: + logger.warning( + f"Invalid MAX_TOKENS value: {env_max_tokens}. 
" + f"Setting to None (default)" + ) + + self._sync_env_vars() + + def _load_from_user_config(self) -> None: + """Load configuration from user's config file.""" + config_path = Path.home() / ".yt-study" / "config.env" + + if not config_path.exists(): + return + + try: + with config_path.open(encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + + if "=" in line: + key, value = line.split("=", 1) + key = key.strip() + value = value.strip() + + # Remove quotes if present + if (value.startswith('"') and value.endswith('"')) or ( + value.startswith("'") and value.endswith("'") + ): + value = value[1:-1] + + if key in self.ALLOWED_KEYS: + # Pre-populate env for consistency + if key not in os.environ: + os.environ[key] = value + else: + logger.warning(f"Ignoring unauthorized config key: {key}") + + except Exception as e: + logger.warning(f"Failed to load config file: {e}") + pass + + def _sync_env_vars(self) -> None: + """Sync class attributes back to os.environ for libraries that expect them.""" + if self.gemini_api_key: + os.environ["GEMINI_API_KEY"] = self.gemini_api_key + if self.openai_api_key: + os.environ["OPENAI_API_KEY"] = self.openai_api_key + if self.anthropic_api_key: + os.environ["ANTHROPIC_API_KEY"] = self.anthropic_api_key + if self.groq_api_key: + os.environ["GROQ_API_KEY"] = self.groq_api_key + if self.xai_api_key: + os.environ["XAI_API_KEY"] = self.xai_api_key + if self.mistral_api_key: + os.environ["MISTRAL_API_KEY"] = self.mistral_api_key + + def get_api_key_name_for_model(self, model: str) -> str | None: + """Get the environment variable name for the API key required by a model.""" + model_lower = model.lower() + + if "gemini" in model_lower or "vertex" in model_lower: + return "GEMINI_API_KEY" + elif "gpt" in model_lower or "openai" in model_lower: + return "OPENAI_API_KEY" + elif "claude" in model_lower or "anthropic" in model_lower: + return "ANTHROPIC_API_KEY" + elif "groq" in 
model_lower: + return "GROQ_API_KEY" + elif "grok" in model_lower or "xai" in model_lower: + return "XAI_API_KEY" + elif "mistral" in model_lower: + return "MISTRAL_API_KEY" + + return None + + def get_api_key_for_model(self, model: str) -> str | None: + """Get the appropriate API key value for a given model.""" + var_name = self.get_api_key_name_for_model(model) + if var_name: + return os.environ.get(var_name) + return None + + +# Global config instance +config = Config() diff --git a/src/yt_study/core/llm/__init__.py b/src/yt_study/core/llm/__init__.py new file mode 100644 index 0000000..2d5e18e --- /dev/null +++ b/src/yt_study/core/llm/__init__.py @@ -0,0 +1 @@ +"""LLM module for multi-provider support and content generation.""" diff --git a/src/yt_study/core/llm/generator.py b/src/yt_study/core/llm/generator.py new file mode 100644 index 0000000..0dfee2c --- /dev/null +++ b/src/yt_study/core/llm/generator.py @@ -0,0 +1,355 @@ +"""Study material generator with chunking and combining logic.""" + +import logging + +from litellm import token_counter +from rich.console import Console +from rich.progress import Progress, TaskID + +from ..config import config +from ..prompts.chapter_notes import ( + get_chapter_prompt, + get_combine_chapters_prompt, +) +from ..prompts.study_notes import ( + SYSTEM_PROMPT, + get_chunk_prompt, + get_combine_prompt, + get_single_pass_prompt, +) +from .providers import LLMProvider + + +# Re-use system prompt for now +CHAPTER_SYSTEM_PROMPT = SYSTEM_PROMPT + +console = Console() +logger = logging.getLogger(__name__) + + +class StudyMaterialGenerator: + """ + Generate study materials from transcripts using LLM. + + Handles token counting, text chunking, and recursive summarization/generation. + """ + + def __init__( + self, + provider: LLMProvider, + temperature: float = 0.7, + max_tokens: int | None = None, + ): + """ + Initialize generator. + + Args: + provider: LLM provider instance. + temperature: LLM response temperature. 
+ max_tokens: Maximum tokens for LLM responses. + """ + self.provider = provider + self.temperature = temperature + self.max_tokens = max_tokens + + def _count_tokens(self, text: str) -> int: + """Count tokens in text using model-specific tokenizer.""" + # Note: token_counter might do network calls for some models or use + # local libraries (tiktoken). For efficiency, we assume it's fast. + try: + count = token_counter(model=self.provider.model, text=text) + return int(count) if count is not None else len(text) // 4 + except Exception: + # Fallback estimation if tokenizer fails (approx 4 chars per token) + return len(text) // 4 + + def _chunk_transcript(self, transcript: str) -> list[str]: + """ + Split transcript into chunks with overlap. + + Uses recursive chunking strategy: + - Target size: Defined in config (default 4000 tokens) + - Overlap: Defined in config (default 200 tokens) + - Priority: Sentence boundaries > Newlines > Words > Hard char limit + + Args: + transcript: The full transcript text. + + Returns: + List of text chunks. + """ + token_count = self._count_tokens(transcript) + + # Fast path: Return single chunk if within limits + if token_count <= config.chunk_size: + return [transcript] + + logger.info( + f"Transcript too long ({token_count:,} tokens), performing chunking..." + ) + + chunks: list[str] = [] + + # Strategy 1: Split by sentences + sentences = transcript.split(". ") + + # Strategy 2: Split by newlines if sentences fail + if len(sentences) < 2 and token_count > config.chunk_size: + sentences = transcript.split("\n") + + # Strategy 3: Split by spaces if newlines fail + if len(sentences) < 2: + sentences = transcript.split(" ") + + current_chunk: list[str] = [] + current_tokens = 0 + + for sentence in sentences: + sentence = sentence.strip() + if not sentence: + continue + + # Re-add delimiter for estimation (approximate) + # We assume '. 
' was the delimiter for simplicity, logic holds + # for others mostly as we care about token count + term = sentence + ". " + term_tokens = self._count_tokens(term) + + # Handle edge case: Single sentence/segment is larger than chunk_size + if term_tokens > config.chunk_size: + # 1. Flush current buffer + if current_chunk: + chunks.append(" ".join(current_chunk)) + current_chunk = [] + current_tokens = 0 + + # 2. Hard split the massive segment + # Estimate char limit based on token size (conservative 3 chars/token) + char_limit = config.chunk_size * 3 + for i in range(0, len(sentence), char_limit): + sub_part = sentence[i : i + char_limit] + chunks.append(sub_part) + continue + + # Standard accumulation + if current_tokens + term_tokens > config.chunk_size: + # Chunk is full. Commit it. + if current_chunk: + chunks.append(" ".join(current_chunk)) + + # Create overlap for next chunk + overlap_chunk: list[str] = [] + overlap_tokens = 0 + + # Take sentences from the end of current_chunk until overlap limit + for prev_sent in reversed(current_chunk): + prev_tokens = self._count_tokens(prev_sent) + if overlap_tokens + prev_tokens <= config.chunk_overlap: + overlap_chunk.insert(0, prev_sent) + overlap_tokens += prev_tokens + else: + break + + current_chunk = overlap_chunk + [sentence] + current_tokens = self._count_tokens(" ".join(current_chunk)) + else: + # Should be unreachable due to check above, but safe fallback + current_chunk.append(sentence) + current_tokens += term_tokens + else: + current_chunk.append(sentence) + current_tokens += term_tokens + + # Add remaining chunk + if current_chunk: + chunks.append(" ".join(current_chunk)) + + logger.info(f"Created {len(chunks)} chunks") + return chunks + + def _update_status( + self, + progress: Progress | None, + task_id: TaskID | None, + video_title: str, + message: str, + ) -> None: + """Safe helper to update progress bar or log message.""" + if progress and task_id is not None: + short_title = ( + (video_title[:20] + 
"...") if len(video_title) > 20 else video_title + ) + # We assume the layout uses 'description' for the status text + progress.update( + task_id, description=f"[yellow]{short_title}[/yellow]: {message}" + ) + else: + logger.info(f"{video_title}: {message}") + + async def generate_study_notes( + self, + transcript: str, + video_title: str = "Video", + progress: Progress | None = None, + task_id: TaskID | None = None, + ) -> str: + """ + Generate study notes from transcript. + + Args: + transcript: Full video transcript text. + video_title: Video title for progress display. + progress: Optional existing progress bar instance. + task_id: Optional task ID for updating progress. + + Returns: + Complete study notes in Markdown format. + """ + chunks = self._chunk_transcript(transcript) + + # Single chunk - generate directly + if len(chunks) == 1: + self._update_status(progress, task_id, video_title, "Generating notes...") + + notes = await self.provider.generate( + system_prompt=SYSTEM_PROMPT, + user_prompt=get_single_pass_prompt(transcript), + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + if not progress: + logger.info(f"Generated notes for {video_title}") + return notes + + # Multiple chunks - generate for each, then combine + self._update_status( + progress, + task_id, + video_title, + f"Generating notes for {len(chunks)} chunks...", + ) + + chunk_notes = [] + + for i, chunk in enumerate(chunks, 1): + msg = f"Chunk {i}/{len(chunks)} (Generating)" + self._update_status(progress, task_id, video_title, msg) + + note = await self.provider.generate( + system_prompt=SYSTEM_PROMPT, + user_prompt=get_chunk_prompt(chunk), + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + chunk_notes.append(note) + + self._update_status( + progress, + task_id, + video_title, + f"Combining {len(chunk_notes)} chunk notes...", + ) + + final_notes = await self.provider.generate( + system_prompt=SYSTEM_PROMPT, + user_prompt=get_combine_prompt(chunk_notes), + 
temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + if not progress: + logger.info(f"Completed notes for {video_title}") + + return final_notes + + async def generate_single_chapter_notes( + self, + chapter_title: str, + chapter_text: str, + ) -> str: + """ + Generate study notes for a single chapter. + + Args: + chapter_title: Title of the chapter. + chapter_text: Transcript text for the chapter. + + Returns: + Study notes for the chapter. + """ + notes = await self.provider.generate( + system_prompt=CHAPTER_SYSTEM_PROMPT, + user_prompt=get_chapter_prompt(chapter_title, chapter_text), + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + return notes + + async def generate_chapter_based_notes( + self, + chapter_transcripts: dict[str, str], + video_title: str = "Video", + progress: Progress | None = None, + task_id: TaskID | None = None, + ) -> str: + """ + Generate study notes using chapter-based approach. + + Args: + chapter_transcripts: Dictionary mapping chapter titles to transcript text. + video_title: Video title for display. + progress: Optional existing progress bar instance. + task_id: Optional task ID for updating progress. + + Returns: + Complete study notes organized by chapters. + """ + # Imports are already at top-level or can be moved up, but let's + # fix the specific issue. Previously we did lazy import inside + # function which caused issues + + self._update_status( + progress, + task_id, + video_title, + f"Generating notes for {len(chapter_transcripts)} chapters...", + ) + + chapter_notes = {} + total_chapters = len(chapter_transcripts) + + for i, (chapter_title, chapter_text) in enumerate( + chapter_transcripts.items(), 1 + ): + msg = f"Chapter {i}/{total_chapters}: {chapter_title[:20]}..." + self._update_status(progress, task_id, video_title, msg) + + # If a chapter is huge, we might need recursive chunking here too. + # For now, we assume chapters are reasonably sized or the model + # can handle ~100k context. 
Future improvement: Check token + # count of chapter_text and recurse if needed. + + notes = await self.provider.generate( + system_prompt=CHAPTER_SYSTEM_PROMPT, + user_prompt=get_chapter_prompt(chapter_title, chapter_text), + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + chapter_notes[chapter_title] = notes + + self._update_status( + progress, task_id, video_title, "Combining chapter notes..." + ) + + final_notes = await self.provider.generate( + system_prompt=CHAPTER_SYSTEM_PROMPT, + user_prompt=get_combine_chapters_prompt(chapter_notes), + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + + if not progress: + logger.info(f"Completed chapter-based notes for {video_title}") + + return final_notes diff --git a/src/yt_study/core/llm/providers.py b/src/yt_study/core/llm/providers.py new file mode 100644 index 0000000..625e45d --- /dev/null +++ b/src/yt_study/core/llm/providers.py @@ -0,0 +1,152 @@ +"""LLM provider configuration using LiteLLM.""" + +import logging +import os +from typing import Any + +from litellm import acompletion +from rich.console import Console + +from ..config import config + + +console = Console() +logger = logging.getLogger(__name__) + + +class LLMGenerationError(Exception): + """Exception raised when LLM generation fails.""" + + pass + + +class LLMProvider: + """ + LLM provider interface using LiteLLM. + + Handles API key verification and text generation with retries. + """ + + def __init__(self, model: str = "gemini/gemini-2.0-flash"): + """ + Initialize LLM provider. + + Args: + model: LiteLLM-compatible model string (e.g., 'gemini/gemini-2.0-flash'). + """ + self.model = model + self._validate_config() + + def _validate_config(self) -> None: + """ + Verify that the necessary API key for the selected model is set. + Logs a warning if missing. 
+ """ + # We rely on Config to check environment variables, + # but we can double check here for the specific model + key_name = config.get_api_key_name_for_model(self.model) + if key_name: + if not os.getenv(key_name): + logger.warning( + f"API Key for model '{self.model}' ({key_name}) not found " + "in environment. Generation may fail." + ) + else: + # If we can't map the model to a specific key (unknown provider), + # we assume the user knows what they are doing or it doesn't need + # one (e.g. ollama) + logger.debug(f"No specific API key mapping found for model: {self.model}") + + async def generate( + self, + system_prompt: str, + user_prompt: str, + temperature: float = 0.7, + max_tokens: int | None = None, + ) -> str: + """ + Generate text using the configured LLM. + + Args: + system_prompt: System/instruction prompt. + user_prompt: User query/content. + temperature: Sampling temperature (0.0 to 1.0). + max_tokens: Maximum tokens to generate (None for model default). + + Returns: + Generated text content. + + Raises: + LLMGenerationError: If generation fails after retries. 
+ """ + try: + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ] + + kwargs: dict[str, Any] = { + "model": self.model, + "messages": messages, + "temperature": temperature, + # LiteLLM handles exponential backoff for RateLimitError + "num_retries": 3, + } + + if max_tokens: + kwargs["max_tokens"] = max_tokens + + # LiteLLM's acompletion handles async requests to various providers + response = await acompletion(**kwargs) + + # safely extract content + if not response.choices or not response.choices[0].message.content: + raise LLMGenerationError("Received empty response from LLM provider") + + content = response.choices[0].message.content.strip() + return self._clean_content(content) + + except Exception as e: + logger.error(f"LLM generation failed with {self.model}: {e}", exc_info=True) + raise LLMGenerationError( + f"Failed to generate with {self.model}: {str(e)}" + ) from e + + def _clean_content(self, content: str) -> str: + """ + Remove markdown code block fencing if the LLM wraps the entire output in it. + + Args: + content: Raw LLM output. + + Returns: + Cleaned content string. + """ + # Check for triple backticks + if content.startswith("```"): + lines = content.splitlines() + # Need at least fence start, content, fence end + if len(lines) >= 2 and lines[0].strip().startswith("```"): + # If the first line is just a fence (with optional language), remove it + # Check if the last line is also a fence + if lines[-1].strip() == "```": + return "\n".join(lines[1:-1]).strip() + # Sometimes LLMs stop abruptly or formatting is weird; + # if it starts with fence, we strip the first line. + # If it ends with fence, strip that too. + return "\n".join(lines[1:]).strip().removesuffix("```").strip() + + return content + + +def get_provider(model: str = "gemini/gemini-2.0-flash") -> LLMProvider: + """ + Factory function to get an LLM provider instance. + + Args: + model: LiteLLM-compatible model string. 
+ + Returns: + Configured LLMProvider instance. + """ + return LLMProvider(model=model) diff --git a/src/yt_study/core/orchestrator.py b/src/yt_study/core/orchestrator.py new file mode 100644 index 0000000..c6c67bd --- /dev/null +++ b/src/yt_study/core/orchestrator.py @@ -0,0 +1,534 @@ +"""Main pipeline orchestrator with concurrent processing.""" + +import asyncio +import logging +import re +from pathlib import Path + +from rich.console import Console +from rich.live import Live +from rich.panel import Panel +from rich.progress import Progress, TaskID +from rich.table import Table + +from ..ui.dashboard import PipelineDashboard +from .config import config +from .llm.generator import StudyMaterialGenerator +from .llm.providers import get_provider +from .youtube.metadata import ( + get_playlist_info, + get_video_chapters, + get_video_duration, + get_video_title, +) +from .youtube.parser import parse_youtube_url +from .youtube.playlist import extract_playlist_videos +from .youtube.transcript import ( + YouTubeIPBlockError, + fetch_transcript, + split_transcript_by_chapters, +) + + +console = Console() +logger = logging.getLogger(__name__) + + +def sanitize_filename(name: str) -> str: + """ + Sanitize a string to be used as a filename. + + Args: + name: Raw filename string. + + Returns: + Sanitized string safe for file systems. + """ + # Remove or replace invalid characters + name = re.sub(r'[<>:"/\\|?*]', "", name) + # Replace multiple spaces with single space + name = re.sub(r"\s+", " ", name) + # Trim and limit length + name = name.strip()[:100] + return name if name else "untitled" + + +class PipelineOrchestrator: + """ + Orchestrates the end-to-end pipeline for video processing. + + Manages concurrency, error handling, and UI updates. 
+ """ + + def __init__( + self, + model: str = "gemini/gemini-2.0-flash", + output_dir: Path | None = None, + languages: list[str] | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + ): + """ + Initialize orchestrator. + + Args: + model: LLM model string. + output_dir: Output directory path. + languages: Preferred transcript languages. + temperature: LLM temperature (defaults to config.temperature). + max_tokens: Max tokens (defaults to config.max_tokens). + """ + self.model = model + self.output_dir = output_dir or config.default_output_dir + self.languages = languages or config.default_languages + self.temperature = ( + temperature if temperature is not None else config.temperature + ) + self.max_tokens = max_tokens if max_tokens is not None else config.max_tokens + self.provider = get_provider(model) + self.generator = StudyMaterialGenerator( + self.provider, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + self.semaphore = asyncio.Semaphore(config.max_concurrent_videos) + + def validate_provider(self) -> bool: + """ + Validate that the API key for the selected provider is set. + + Returns: + True if valid (or warning logged), False if critical missing config. + """ + key_name = config.get_api_key_name_for_model(self.model) + + if key_name: + import os + + if not os.environ.get(key_name): + console.print( + f"\n[red bold]✗ Missing API Key for {self.model}[/red bold]" + ) + console.print( + f"[yellow]Expected environment variable: {key_name}[/yellow]" + ) + console.print( + "[dim]Please check your .env file or run:[/dim] " + "[cyan]yt-study setup[/cyan]\n" + ) + return False + + return True + + async def process_video( + self, + video_id: str, + output_path: Path, + progress: Progress | None = None, + task_id: TaskID | None = None, + video_title: str | None = None, + is_playlist: bool = False, + ) -> bool: + """ + Process a single video: fetch transcript and generate study notes. 
    async def process_video(
        self,
        video_id: str,
        output_path: Path,
        progress: Progress | None = None,
        task_id: TaskID | None = None,
        video_title: str | None = None,
        is_playlist: bool = False,
    ) -> bool:
        """
        Process a single video: fetch transcript and generate study notes.

        Long (>1h) standalone videos with chapters are split into one notes
        file per chapter inside a folder named after the video; everything
        else is written to `output_path` as a single Markdown file.

        Args:
            video_id: YouTube Video ID.
            output_path: Destination path for the MD file.
            progress: Rich Progress instance.
            task_id: Rich TaskID.
            video_title: Pre-fetched title (optional).
            is_playlist: Whether this is part of a playlist (affects UI logging).

        Returns:
            True on success, False on failure.
        """
        # Concurrency gate: at most config.max_concurrent_videos at once.
        async with self.semaphore:
            local_task_id = task_id

            # If standalone (not part of the worker pool), create a dedicated
            # progress bar for this video when a Progress is supplied.
            if is_playlist and progress and task_id is None:
                display_title = (video_title or video_id)[:30]
                local_task_id = progress.add_task(
                    description=f"[cyan]⏳ {display_title}... (Waiting)[/cyan]",
                    total=None,
                )

            try:
                # 1. Fetch Metadata
                if not video_title:
                    # Run in a thread to avoid blocking the event loop.
                    video_title = await asyncio.to_thread(get_video_title, video_id)

                # Fetch duration and chapters concurrently (both blocking).
                duration, chapters = await asyncio.gather(
                    asyncio.to_thread(get_video_duration, video_id),
                    asyncio.to_thread(get_video_chapters, video_id),
                )

                title_display = (video_title or video_id)[:40]

                if progress and local_task_id is not None:
                    progress.update(
                        local_task_id,
                        description=f"[cyan]📥 {title_display}... (Transcript)[/cyan]",
                    )

                # 2. Fetch Transcript
                transcript_obj = await fetch_transcript(video_id, self.languages)

                # 3. Determine Generation Strategy
                # Use chapters only for long (>1h) standalone videos that
                # actually have chapter markers.
                use_chapters = duration > 3600 and len(chapters) > 0 and not is_playlist

                if use_chapters:
                    if progress and local_task_id is not None:
                        progress.update(
                            local_task_id,
                            description=(
                                f"[cyan]📖 {title_display}... (Chapters)[/cyan]"
                            ),
                        )

                    # Split the transcript along chapter boundaries.
                    chapter_transcripts = split_transcript_by_chapters(
                        transcript_obj, chapters
                    )

                    # One folder per video; one file per chapter.
                    safe_title = sanitize_filename(video_title)
                    output_folder = self.output_dir / safe_title
                    output_folder.mkdir(parents=True, exist_ok=True)

                    # Generate and save each chapter individually so partial
                    # progress survives a mid-run failure.
                    for i, (chap_title, chap_text) in enumerate(
                        chapter_transcripts.items(), 1
                    ):
                        status_msg = f"Chapter {i}/{len(chapter_transcripts)}"
                        if progress and local_task_id is not None:
                            progress.update(
                                local_task_id,
                                description=(
                                    f"[cyan]🤖 {title_display}... ({status_msg})[/cyan]"
                                ),
                            )

                        notes = await self.generator.generate_single_chapter_notes(
                            chapter_title=chap_title,
                            chapter_text=chap_text,
                        )

                        # Prefix with index so files sort in chapter order.
                        safe_chapter = sanitize_filename(chap_title)
                        chapter_file = output_folder / f"{i:02d}_{safe_chapter}.md"
                        chapter_file.write_text(notes, encoding="utf-8")

                    if progress and local_task_id is not None:
                        progress.update(
                            local_task_id,
                            description=f"[green]✓ {title_display} (Done)[/green]",
                            completed=True,
                        )

                    return True

                else:
                    # Single file generation
                    transcript_text = transcript_obj.to_text()

                    if progress and local_task_id is not None:
                        progress.update(
                            local_task_id,
                            description=(
                                f"[cyan]🤖 {title_display}... (Generating)[/cyan]"
                            ),
                        )

                    notes = await self.generator.generate_study_notes(
                        transcript_text,
                        video_title=title_display,
                        progress=progress,
                        task_id=local_task_id,
                    )

                    output_path.parent.mkdir(parents=True, exist_ok=True)
                    output_path.write_text(notes, encoding="utf-8")

                    if progress and local_task_id is not None:
                        progress.update(
                            local_task_id,
                            description=f"[green]✓ {title_display} (Done)[/green]",
                            completed=True,
                        )

                    return True

            except Exception as e:
                logger.error(f"Failed to process {video_id}: {e}")

                err_msg = str(e)
                # IP-block failures get a prominent panel because the fix
                # (VPN / waiting) is on the user's side, not the tool's.
                if isinstance(e, YouTubeIPBlockError) or (
                    "blocking requests" in err_msg
                ):
                    err_display = "[bold red]IP BLOCKED[/bold red]"
                    console.print(
                        Panel(
                            "[bold red]🚫 YouTube IP Block Detected[/bold red]\n\n"
                            "YouTube is limiting requests from your IP address.\n"
                            "[yellow]➤ Recommendation:[/yellow] Use a VPN or "
                            "wait ~1 hour.",
                            border_style="red",
                        )
                    )
                else:
                    err_display = "(Failed)"

                if progress and local_task_id is not None:
                    progress.update(
                        local_task_id,
                        description=(
                            f"[red]✗ {(video_title or video_id)[:20]}... "
                            f"{err_display}[/red]"
                        ),
                        visible=True,
                    )

                return False
state + with Live(dashboard, refresh_per_second=10, console=console, screen=False): + # --- Phase 1: Metadata Fetching --- + TITLE_FETCH_CONCURRENCY = 10 + if not is_single_video: + dashboard.update_overall_status( + f"[cyan]📋 Fetching metadata for {len(video_ids)} videos...[/cyan]" + ) + + title_semaphore = asyncio.Semaphore(TITLE_FETCH_CONCURRENCY) + + async def fetch_title_safe(vid: str) -> str: + async with title_semaphore: + try: + return await asyncio.to_thread(get_video_title, vid) + except Exception: + return vid + + # Fetch titles + titles = await asyncio.gather(*(fetch_title_safe(vid) for vid in video_ids)) + video_titles = dict(zip(video_ids, titles, strict=True)) + + # --- Phase 2: Processing --- + if not is_single_video: + dashboard.update_overall_status("[bold blue]Total Progress[/bold blue]") + + # Determine base output folder + if is_single_video: + base_folder = self.output_dir + else: + base_folder = self.output_dir / sanitize_filename(playlist_name) + base_folder.mkdir(parents=True, exist_ok=True) + + # Worker Queue Implementation + queue: asyncio.Queue[str] = asyncio.Queue() + for vid in video_ids: + queue.put_nowait(vid) + + async def worker(worker_idx: int, task_id: TaskID) -> None: + nonlocal success_count + while not queue.empty(): + try: + video_id = queue.get_nowait() + except asyncio.QueueEmpty: + break + + title = video_titles.get(video_id, video_id) + safe_title = sanitize_filename(title) + + if is_single_video: + video_folder = base_folder / safe_title + output_path = video_folder / f"{safe_title}.md" + else: + output_path = base_folder / f"{safe_title}.md" + + # Update status + dashboard.update_worker( + worker_idx, f"[yellow]{title[:30]}...[/yellow]" + ) + + try: + result = await self.process_video( + video_id, + output_path, + progress=dashboard.worker_progress, + task_id=task_id, + video_title=title, + is_playlist=not is_single_video, + ) + + if result: + success_count += 1 + dashboard.add_completion(title) + else: + 
dashboard.add_failure(title) + + except Exception as e: + logger.error(f"Worker {worker_idx} failed on {video_id}: {e}") + dashboard.update_worker(worker_idx, f"[red]Error: {e}[/red]") + dashboard.add_failure(title) + await asyncio.sleep(2) + finally: + queue.task_done() + + # Worker done + dashboard.update_worker(worker_idx, "[dim]Idle[/dim]") + + try: + workers = [ + asyncio.create_task(worker(i, dashboard.worker_tasks[i])) + for i in range(actual_concurrency) + ] + await asyncio.gather(*workers) + except Exception as e: + logger.error(f"Dashboard execution failed: {e}") + + # Print summary table after dashboard closes + self._print_summary(dashboard) + + return success_count + + def _print_summary(self, dashboard: "PipelineDashboard") -> None: + """Print a summary table of the run.""" + if not dashboard.recent_completions and not dashboard.recent_failures: + return + + summary_table = Table( + title="📊 Processing Summary", + border_style="cyan", + show_header=True, + header_style="bold magenta", + ) + summary_table.add_column("Status", justify="center") + summary_table.add_column("Video Title", style="dim") + + # Add failures first (more important) + if dashboard.recent_failures: + for fail in dashboard.recent_failures: + summary_table.add_row("[bold red]FAILED[/bold red]", fail) + + # Add successes + if dashboard.recent_completions: + for comp in dashboard.recent_completions: + summary_table.add_row("[green]SUCCESS[/green]", comp) + + console.print("\n") + console.print(summary_table) + console.print( + f"\n[bold]Total Completed:[/bold] " + f"{dashboard.overall_progress.tasks[0].completed}/" + f"{dashboard.overall_progress.tasks[0].total}" + ) + console.print("[dim]Check logs for detailed error reports.[/dim]\n") + + async def process_playlist( + self, playlist_id: str, playlist_name: str = "playlist" + ) -> int: + """Process playlist with concurrent dynamic progress bars.""" + video_ids = await extract_playlist_videos(playlist_id) + return await 
    async def process_playlist(
        self, playlist_id: str, playlist_name: str = "playlist"
    ) -> int:
        """
        Process playlist with concurrent dynamic progress bars.

        Args:
            playlist_id: YouTube playlist ID.
            playlist_name: Display name used by the dashboard and output folder.

        Returns:
            Number of videos processed successfully.
        """
        video_ids = await extract_playlist_videos(playlist_id)
        return await self._process_with_dashboard(video_ids, playlist_name)

    async def run(self, url: str) -> None:
        """
        Run the pipeline for a given YouTube URL.

        Dispatches to the single-video or playlist flow based on the parsed
        URL type. All result reporting is done by the dashboard.

        Args:
            url: YouTube video or playlist URL.
        """
        # Validate provider credentials before doing any network work.
        if not self.validate_provider():
            return

        try:
            # Parse URL
            parsed = parse_youtube_url(url)

            if parsed.url_type == "video":
                if not parsed.video_id:
                    console.print("[red]Error: Video ID could not be extracted[/red]")
                    return

                await self._process_with_dashboard(
                    [parsed.video_id],
                    playlist_name="Single Video",
                    is_single_video=True,
                )
                # Summary is already printed by _process_with_dashboard

            elif parsed.url_type == "playlist":
                if not parsed.playlist_id:
                    console.print(
                        "[red]Error: Playlist ID could not be extracted[/red]"
                    )
                    return

                # Fetch the playlist title up front so the dashboard header
                # can show it; per-video titles are fetched by the dashboard
                # itself. NOTE(review): this blocks briefly before the UI
                # appears — acceptable for a single metadata call.
                playlist_title, _ = await asyncio.to_thread(
                    get_playlist_info, parsed.playlist_id
                )

                await self.process_playlist(parsed.playlist_id, playlist_title)
                # Summary handled by dashboard

        except ValueError as e:
            # parse_youtube_url signals malformed input via ValueError.
            console.print(f"[red]Input Error: {e}[/red]")
        except Exception as e:
            console.print(f"[red]Unexpected Error: {e}[/red]")
            logger.exception("Pipeline run failed")
PipelineEvent: + """Event emitted during pipeline execution.""" + + event_type: EventType + video_id: str + title: str | None = None + chapter_number: int | None = None + total_chapters: int | None = None + error: str | None = None + output_path: Path | None = None + + +@dataclass +class PipelineResult: + """Result of pipeline execution.""" + + success_count: int + failure_count: int + total_count: int + video_ids: list[str] + errors: dict[str, str] # video_id -> error message + + +def sanitize_filename(name: str) -> str: + """ + Sanitize a string to be used as a filename. + + Args: + name: Raw filename string. + + Returns: + Sanitized string safe for file systems. + """ + name = re.sub(r'[<>:"/\\|?*]', "", name) + name = re.sub(r"\s+", " ", name) + name = name.strip()[:100] + return name if name else "untitled" + + +class CorePipeline: + """ + Core pipeline orchestrator. + + Pure business logic - no UI concerns. + Single public entry point: `run()` + + Communicates progress via event callbacks. + """ + + def __init__( + self, + model: str = "gemini/gemini-2.0-flash", + output_dir: Path | None = None, + languages: list[str] | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + ): + """ + Initialize the core pipeline. + + Args: + model: LLM model string. + output_dir: Output directory path. + languages: Preferred transcript languages. + temperature: LLM temperature. + max_tokens: Max tokens for generation. 
+ """ + self.model = model + self.output_dir = output_dir or config.default_output_dir + self.languages = languages or config.default_languages + self.temperature = ( + temperature if temperature is not None else config.temperature + ) + self.max_tokens = max_tokens if max_tokens is not None else config.max_tokens + + self.provider = get_provider(model) + self.generator = StudyMaterialGenerator( + self.provider, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + self.semaphore = asyncio.Semaphore(config.max_concurrent_videos) + self.errors: dict[str, str] = {} + + def _check_api_key(self) -> bool: + """ + Check if API key is configured. + + Returns: + True if valid, False otherwise. + Errors are logged but not printed (UI's responsibility). + """ + key_name = config.get_api_key_name_for_model(self.model) + + if key_name: + import os + + if not os.environ.get(key_name): + logger.error(f"Missing API Key for {self.model}. Expected: {key_name}") + return False + return True + + async def _process_single_video( + self, + video_id: str, + output_path: Path, + on_event: Callable[[PipelineEvent], None] | None = None, + ) -> bool: + """ + Process a single video: fetch transcript and generate study notes. + + This is an INTERNAL method (async worker). + It emits events that the CLI/UI can listen to. + + Args: + video_id: YouTube Video ID. + output_path: Destination path for the MD file. + on_event: Callback for progress events. + + Returns: + True on success, False on failure. 
+ """ + async with self.semaphore: + try: + # --- Metadata Phase --- + emit = self._emit_event(on_event) + emit(EventType.METADATA_START, video_id) + + # Fetch metadata concurrently + title = await asyncio.to_thread(get_video_title, video_id) + duration, chapters = await asyncio.gather( + asyncio.to_thread(get_video_duration, video_id), + asyncio.to_thread(get_video_chapters, video_id), + ) + + emit( + EventType.METADATA_FETCHED, + video_id, + title=title, + chapter_number=len(chapters) if chapters else 0, + ) + + # --- Transcript Phase --- + emit(EventType.TRANSCRIPT_FETCHING, video_id, title=title) + + transcript_obj = await fetch_transcript(video_id, self.languages) + + emit(EventType.TRANSCRIPT_FETCHED, video_id, title=title) + + # --- Generation Strategy --- + use_chapters = duration > 3600 and len(chapters) > 0 + + if use_chapters: + # Chapter-based generation + chapter_transcripts = split_transcript_by_chapters( + transcript_obj, chapters + ) + + safe_title = sanitize_filename(title) + output_folder = self.output_dir / safe_title + output_folder.mkdir(parents=True, exist_ok=True) + + total_chapters = len(chapter_transcripts) + + for i, (chap_title, chap_text) in enumerate( + chapter_transcripts.items(), 1 + ): + emit( + EventType.CHAPTER_GENERATING, + video_id, + title=title, + chapter_number=i, + total_chapters=total_chapters, + ) + + notes = await self.generator.generate_single_chapter_notes( + chapter_title=chap_title, + chapter_text=chap_text, + ) + + safe_chapter = sanitize_filename(chap_title) + chapter_file = output_folder / f"{i:02d}_{safe_chapter}.md" + chapter_file.write_text(notes, encoding="utf-8") + + emit( + EventType.GENERATION_COMPLETE, + video_id, + title=title, + output_path=output_folder, + ) + emit(EventType.VIDEO_SUCCESS, video_id, title=title) + return True + + else: + # Single file generation + emit(EventType.GENERATION_START, video_id, title=title) + + transcript_text = transcript_obj.to_text() + notes = await 
self.generator.generate_study_notes( + transcript_text, + video_title=title, + ) + + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text(notes, encoding="utf-8") + + emit( + EventType.GENERATION_COMPLETE, + video_id, + title=title, + output_path=output_path, + ) + emit(EventType.VIDEO_SUCCESS, video_id, title=title) + return True + + except YouTubeIPBlockError as e: + error_msg = "YouTube IP blocked - use VPN or wait 1 hour" + logger.error(f"IP Block for {video_id}: {e}") + self.errors[video_id] = error_msg + emit(EventType.VIDEO_FAILED, video_id, error=error_msg) + return False + + except Exception as e: + error_msg = f"{type(e).__name__}: {str(e)}" + logger.error(f"Failed to process {video_id}: {e}", exc_info=True) + self.errors[video_id] = error_msg + emit(EventType.VIDEO_FAILED, video_id, error=error_msg) + return False + + def _emit_event( + self, + on_event: Callable[[PipelineEvent], None] | None, + ) -> Callable[..., None]: + """ + Create a helper function to emit events. + + This allows cleaner event emission throughout the code. + """ + + def emit( + event_type: EventType, + video_id: str, + title: str | None = None, + chapter_number: int | None = None, + total_chapters: int | None = None, + error: str | None = None, + output_path: Path | None = None, + ) -> None: + if on_event: + event = PipelineEvent( + event_type=event_type, + video_id=video_id, + title=title, + chapter_number=chapter_number, + total_chapters=total_chapters, + error=error, + output_path=output_path, + ) + try: + on_event(event) + except Exception as e: + logger.warning(f"Event handler error: {e}") + + return emit + + async def run( + self, + video_ids: list[str], + on_event: Callable[[PipelineEvent], None] | None = None, + ) -> PipelineResult: + """ + ✅ SINGLE ENTRY POINT FOR CLI AND OTHER FRONTENDS + + Process a list of video IDs concurrently. + + Args: + video_ids: List of YouTube video IDs to process. + on_event: Optional callback for progress events. 
+ Signature: (event: PipelineEvent) -> None + + Returns: + PipelineResult with success count, failures, and detailed errors. + + Example (CLI usage): + >>> pipeline = CorePipeline(model="gemini-1.5-flash") + >>> + >>> def on_progress(event): + ... if event.event_type == EventType.VIDEO_SUCCESS: + ... print(f"✓ {event.title}") + ... elif event.event_type == EventType.VIDEO_FAILED: + ... print(f"✗ {event.title}: {event.error}") + >>> + >>> result = await pipeline.run( + ... ["VIDEO_ID_1", "VIDEO_ID_2"], + ... on_event=on_progress + ... ) + >>> print(f"Completed: {result.success_count}/{result.total_count}") + """ + # --- Validation --- + if not self._check_api_key(): + return PipelineResult( + success_count=0, + failure_count=len(video_ids), + total_count=len(video_ids), + video_ids=video_ids, + errors={vid: "Missing API key" for vid in video_ids}, + ) + + if not video_ids: + return PipelineResult( + success_count=0, + failure_count=0, + total_count=0, + video_ids=[], + errors={}, + ) + + emit = self._emit_event(on_event) + emit(EventType.PIPELINE_START, video_ids[0]) + + self.errors.clear() + success_count = 0 + + # --- Process all videos concurrently --- + tasks = [] + for video_id in video_ids: + safe_title = sanitize_filename(video_id) + output_path = self.output_dir / f"{safe_title}.md" + + task = self._process_single_video( + video_id, + output_path, + on_event=on_event, + ) + tasks.append(task) + + # Gather results + results = await asyncio.gather(*tasks, return_exceptions=False) + success_count = sum(1 for r in results if r is True) + + # --- Return Structured Result --- + failure_count = len(video_ids) - success_count + + emit(EventType.PIPELINE_COMPLETE, video_ids[0]) + + return PipelineResult( + success_count=success_count, + failure_count=failure_count, + total_count=len(video_ids), + video_ids=video_ids, + errors=self.errors, + ) + + +async def run_pipeline( + video_ids: list[str], + output_dir: Path | None = None, + model: str = 
"gemini/gemini-2.0-flash", + on_event: Callable[[PipelineEvent], None] | None = None, +) -> PipelineResult: + """ + Convenience function for simple usage. + + Alternative to CorePipeline class. + + Args: + video_ids: List of YouTube video IDs to process. + output_dir: Optional output directory. + model: LLM model string. + on_event: Optional callback for progress events. + + Returns: + PipelineResult with success/failure counts. + """ + pipeline = CorePipeline(model=model, output_dir=output_dir) + return await pipeline.run(video_ids, on_event=on_event) diff --git a/src/yt_study/core/prompts/__init__.py b/src/yt_study/core/prompts/__init__.py new file mode 100644 index 0000000..5bc08d8 --- /dev/null +++ b/src/yt_study/core/prompts/__init__.py @@ -0,0 +1 @@ +"""Prompt templates for study material generation.""" diff --git a/src/yt_study/core/prompts/chapter_notes.py b/src/yt_study/core/prompts/chapter_notes.py new file mode 100644 index 0000000..2e88328 --- /dev/null +++ b/src/yt_study/core/prompts/chapter_notes.py @@ -0,0 +1,55 @@ +"""Prompt templates for chapter-based study material generation.""" + +# Prompt for generating notes from a single chapter +CHAPTER_GENERATION_PROMPT = """ +Create an in-depth, detailed study guide for this specific chapter: + +Chapter Title: {chapter_title} + +Transcript: +{transcript_chunk} + +Requirements: +1. **Deep Dive**: Provide a thorough, granular explanation of the chapter's topic. +2. **Comprehensive**: Include every nuance, sub-point, and detail mentioned. +3. **Clarify Concepts**: Explain "why" and "how" for every concept, not just "what". +4. **Examples**: Preserve all examples and use them to illustrate technical points. +5. **Structure**: Use deeply nested headers (###, ####) to break down complex ideas. +6. Pure Markdown format. +7. English language. +8. **DO NOT include any opening or closing conversational text.** +9. 
**Start directly with the first header (e.g., # Chapter Title)**""" + + +# Prompt for combining chapter notes +COMBINE_CHAPTER_NOTES_PROMPT = """ +You have generated study notes for different chapters of the same video. +Combine these chapter notes into a single, well-organized study document. + +Video chapters and notes: +{chapter_notes} + +Requirements: +1. Keep chapter structure with clear headers (## Chapter Title) +2. Ensure logical flow between chapters +3. Remove redundancies while preserving all unique information +4. Add a brief introduction summarizing what the video covers +5. Maintain all important details from every chapter +6. Use proper Markdown hierarchy (##, ###, etc.) +7. Do NOT add a table of contents +8. Create a cohesive document that's easy to navigate and review""" + + +def get_chapter_prompt(chapter_title: str, transcript_chunk: str) -> str: + """Generate prompt for a chapter.""" + return CHAPTER_GENERATION_PROMPT.format( + chapter_title=chapter_title, transcript_chunk=transcript_chunk + ) + + +def get_combine_chapters_prompt(chapter_notes: dict[str, str]) -> str: + """Generate prompt for combining chapter notes.""" + combined = "\n\n".join( + [f"## {title}\n\n{notes}" for title, notes in chapter_notes.items()] + ) + return COMBINE_CHAPTER_NOTES_PROMPT.format(chapter_notes=combined) diff --git a/src/yt_study/core/prompts/study_notes.py b/src/yt_study/core/prompts/study_notes.py new file mode 100644 index 0000000..02b3f6d --- /dev/null +++ b/src/yt_study/core/prompts/study_notes.py @@ -0,0 +1,106 @@ +"""Prompt templates for study material generation and chunk combining.""" + +# System prompt for generating study notes from transcript chunks +SYSTEM_PROMPT = """ +You are an expert academic tutor and technical writer dedicated to creating +the most comprehensive study materials possible. + +Your goal is to transform video transcripts into deep, detailed, and highly +structured study notes. 
"""Prompt templates for study material generation and chunk combining."""

# System prompt for generating study notes from transcript chunks
SYSTEM_PROMPT = """
You are an expert academic tutor and technical writer dedicated to creating
the most comprehensive study materials possible.

Your goal is to transform video transcripts into deep, detailed, and highly
structured study notes.

You prioritize:
- **Depth**: Go beyond surface-level summaries. Explain *why* and *how*, not
  just *what*.
- **Comprehensive Coverage**: Capture every single concept, detail, nuance,
  and example mentioned.
- **Clarity**: Use clear, academic yet accessible language. Break down complex topics.
- **Structure**: Use logical hierarchy (headers, subheaders) to organize
  information effectively.

Always generate output in clean Markdown format."""

# User prompt for individual transcript chunks
CHUNK_GENERATION_PROMPT = """
Create extremely detailed and in-depth study notes from this transcript
segment:

{transcript_chunk}

Requirements:
1. **Comprehensive Coverage**: Cover EVERY concept, definition, theory, and
   significant detail mentioned. Do not summarize; expand.
2. **In-Depth Explanation**: Explain complex ideas thoroughly. If a process
   is described, break it down step-by-step.
3. **Capture Examples & Code**: Include ALL examples, case studies, and
   especially **CODE BLOCKS/SQL** provided in the transcript.
4. **Technical Precision**: Use actual SQL syntax for table definitions
   (e.g., `CREATE TABLE`), not just descriptions.
5. **Logical Structure**: Use deep hierarchy (##, ###, ####) to organize
   related concepts.
6. **Key Terminology**: Highlight and define technical terms or important vocabulary.
7. **Pure Markdown**: No HTML, no table of contents.
8. **Clean Start**: Start directly with the content headers, no conversational filler.
9. **Language**: English."""

# Prompt for combining multiple chunk notes into final document
COMBINE_CHUNKS_PROMPT = """
You have generated study notes for multiple segments of the same video. Now
combine these segments into a single, coherent study document.

Segment notes:
{chunk_notes}

Requirements:
1. Merge all segments into a unified, flowing document
2. **Preserve ALL Content**: Do NOT summarize or condense. Retain all
   explanations, examples, code blocks, and details.
3. **Preserve Code & Syntax**: Use valid `CREATE TABLE` SQL and other
   specific syntax exactly as presented.
4. **Seamless Merge**: Connect segments smoothly, but do not delete content for brevity.
5. **Detailed & Comprehensive**: The final document must be as detailed as
   the input segments combined.
6. Maintain consistent formatting and structure (##, ###).
7. Do NOT add a table of contents.
8. **Example clean output:** "# Title\\n\\n## Section 1..."

Create study notes that are comprehensive, well-organized, and easy to review."""

# Prompt for single-pass generation (small transcripts)
SINGLE_PASS_PROMPT = """
Create an extensive and in-depth study guide from this complete video
transcript:

{transcript}

Requirements:
1. **Exhaustive Coverage**: Cover every single topic discussed. Do not leave
   out details.
2. **Deep Understanding**: Explain concepts clearly and thoroughly, as if
   teaching a student.
3. **Structured Learning**: Use a clear, logical hierarchy (##, ###, ####)
   to organize topics.
4. **Examples & Context**: Retain all illustrative examples and context
   provided in the video.
5. **No Summarization**: Do not summarize brief points; expand them for full
   understanding.
6. Pure Markdown format (no HTML, no table of contents).
7. English language output.
8. **Clean Start**: Start directly with the first header (e.g. # Video
   Title), no filler."""


def get_chunk_prompt(transcript_chunk: str) -> str:
    """Build the per-chunk note-generation prompt."""
    return CHUNK_GENERATION_PROMPT.format(transcript_chunk=transcript_chunk)


def get_combine_prompt(chunk_notes: list[str]) -> str:
    """Build the prompt that merges per-chunk notes into one document."""
    sections = (
        f"## Segment {index}\n\n{note}"
        for index, note in enumerate(chunk_notes, start=1)
    )
    return COMBINE_CHUNKS_PROMPT.format(chunk_notes="\n\n---\n\n".join(sections))


def get_single_pass_prompt(transcript: str) -> str:
    """Build the whole-transcript, single-pass prompt."""
    return SINGLE_PASS_PROMPT.format(transcript=transcript)
["grok", "xai"], + }, + "mistral": { + "name": "Mistral AI", + "env_var": "MISTRAL_API_KEY", + "api_url": "https://console.mistral.ai/api-keys/", + "keywords": ["mistral"], + }, + "cohere": { + "name": "Cohere", + "env_var": "COHERE_API_KEY", + "api_url": "https://dashboard.cohere.com/api-keys", + "keywords": ["cohere", "command"], + }, + "deepseek": { + "name": "DeepSeek", + "env_var": "DEEPSEEK_API_KEY", + "api_url": "https://platform.deepseek.com/api_keys", + "keywords": ["deepseek"], + }, +} + + +def get_config_path() -> Path: + """Get path to user config file.""" + config_dir = Path.home() / ".yt-study" + config_dir.mkdir(exist_ok=True) + return config_dir / "config.env" + + +def load_config() -> dict[str, str]: + """Load existing configuration.""" + config_path = get_config_path() + loaded_config = {} + + if config_path.exists(): + try: + with config_path.open(encoding="utf-8") as f: + for line in f: + line = line.strip() + if line and not line.startswith("#") and "=" in line: + key, value = line.split("=", 1) + loaded_config[key.strip()] = value.strip() + except Exception: + # If corrupted, return empty + pass + + return loaded_config + + +def save_config(new_config: dict[str, str]) -> None: + """ + Save configuration to file, preserving existing keys. + + Args: + new_config: Dictionary of new configuration values to merge/update. 
+ """ + config_path = get_config_path() + current_config = load_config() + + # Merge new config into current config + current_config.update(new_config) + + with config_path.open("w", encoding="utf-8") as f: + f.write("# yt-study Configuration\n") + f.write("# Generated by yt-study setup wizard\n\n") + + # Sort keys for consistent output, prioritize critical ones + priority_keys = ["DEFAULT_MODEL", "OUTPUT_DIR", "MAX_CONCURRENT_VIDEOS"] + for key in priority_keys: + if key in current_config: + f.write(f"{key}={current_config[key]}\n") + + # Write remaining keys + for key, value in sorted(current_config.items()): + if key not in priority_keys: + f.write(f"{key}={value}\n") + + console.print( + f"\n[green]✓[/green] Configuration saved to: [cyan]{config_path}[/cyan]" + ) + + +def get_available_models() -> dict[str, list[str]]: + """Fetch available models from LiteLLM.""" + try: + # Lazy import to avoid slow startup if not running setup + from litellm import model_list + + # Group models by provider + provider_models: dict[str, list[str]] = {} + + for model in model_list: + # Determine provider from model name + provider = None + model_lower = model.lower() + + for prov_key, prov_config in PROVIDER_CONFIG.items(): + if any(keyword in model_lower for keyword in prov_config["keywords"]): + provider = prov_key + break + + if provider: + if provider not in provider_models: + provider_models[provider] = [] + provider_models[provider].append(model) + + # Sort models within each provider + for provider in provider_models: + provider_models[provider] = sorted(set(provider_models[provider])) + + return provider_models + + except Exception as e: + console.print(f"[yellow]⚠ Could not fetch models from LiteLLM: {e}[/yellow]") + console.print("[yellow]Using fallback model list...[/yellow]") + + # Fallback to curated list + return { + "gemini": [ + "gemini/gemini-2.0-flash-exp", + "gemini/gemini-2.0-flash", + "gemini/gemini-1.5-pro", + "gemini/gemini-1.5-flash", + ], + "openai": [ + 
"gpt-4o", + "gpt-4o-mini", + "gpt-4-turbo", + "o1", + "o1-mini", + ], + "anthropic": [ + "anthropic/claude-3-5-sonnet-20241022", + "anthropic/claude-3-5-haiku-20241022", + "anthropic/claude-3-opus-20240229", + ], + "groq": [ + "groq/llama-3.3-70b-versatile", + "groq/llama-3.1-8b-instant", + "groq/mixtral-8x7b-32768", + ], + "xai": [ + "xai/grok-2-latest", + "xai/grok-2-vision-latest", + ], + } + + +def select_provider(available_models: dict[str, list[str]]) -> str: + """Interactive provider selection.""" + console.print("\n[bold cyan]Select LLM Provider:[/bold cyan]\n") + + # Create table of providers + table = Table(show_header=True, header_style="bold magenta") + table.add_column("#", style="dim", width=4) + table.add_column("Provider", style="cyan") + table.add_column("Models Available", style="dim") + + providers_list = [] + for prov_key, _prov_config in PROVIDER_CONFIG.items(): + if prov_key in available_models: + providers_list.append(prov_key) + + for i, provider_key in enumerate(providers_list, 1): + config_data = PROVIDER_CONFIG[provider_key] + model_count = len(available_models.get(provider_key, [])) + table.add_row(str(i), config_data["name"], f"{model_count} models") + + console.print(table) + console.print(f"\n[dim]Total providers: {len(providers_list)}[/dim]") + + while True: + choice = Prompt.ask( + "\nSelect provider", + choices=[str(i) for i in range(1, len(providers_list) + 1)], + ) + return providers_list[int(choice) - 1] + + +def select_model(provider_key: str, available_models: dict[str, list[str]]) -> str: + """Interactive model selection.""" + provider_config = PROVIDER_CONFIG[provider_key] + models = available_models.get(provider_key, []) + + if not models: + console.print(f"[yellow]No models found for {provider_config['name']}[/yellow]") + return f"{provider_key}/default" + + console.print(f"\n[bold cyan]Select {provider_config['name']} Model:[/bold cyan]\n") + console.print(f"[dim]Showing {len(models)} available models[/dim]\n") + + # Show 
models in pages of 20 + page_size = 20 + current_page = 0 + + while True: + start_idx = current_page * page_size + end_idx = min(start_idx + page_size, len(models)) + page_models = models[start_idx:end_idx] + + # Create table of models + table = Table(show_header=True, header_style="bold magenta") + table.add_column("#", style="dim", width=4) + table.add_column("Model", style="green") + + for i, model in enumerate(page_models, start_idx + 1): + # Highlight recommended models + model_display = model + if "flash" in model.lower() or "mini" in model.lower(): + model_display = f"{model} [dim](fast)[/dim]" + elif ( + "pro" in model.lower() + or "turbo" in model.lower() + or "sonnet" in model.lower() + ): + model_display = f"{model} [dim](powerful)[/dim]" + + table.add_row(str(i), model_display) + + console.print(table) + + # Navigation info + total_pages = (len(models) + page_size - 1) // page_size + console.print( + f"\n[dim]Page {current_page + 1}/{total_pages} | " + f"Showing {start_idx + 1}-{end_idx} of {len(models)} " + f"models[/dim]" + ) + + if total_pages > 1: + console.print( + "[dim]Type 'n' for next page, 'p' for previous page, " + "or model number to select[/dim]" + ) + + choice = Prompt.ask("\nSelect model (or n/p for navigation)") + + if choice.lower() == "n" and current_page < total_pages - 1: + current_page += 1 + console.clear() + console.print( + f"\n[bold cyan]Select {provider_config['name']} Model:[/bold cyan]\n" + ) + continue + elif choice.lower() == "p" and current_page > 0: + current_page -= 1 + console.clear() + console.print( + f"\n[bold cyan]Select {provider_config['name']} Model:[/bold cyan]\n" + ) + continue + elif choice.isdigit() and 1 <= int(choice) <= len(models): + selected = models[int(choice) - 1] + + # Ensure Gemini models have correct prefix for Google AI Studio + if ( + provider_key == "gemini" + and not selected.startswith("gemini/") + and not selected.startswith("vertex_ai/") + ): + return f"gemini/{selected}" + + return selected 
+ + +def get_api_key(provider_key: str, existing_key: str | None = None) -> str: + """Prompt for API key.""" + provider = PROVIDER_CONFIG[provider_key] + + console.print(f"\n[bold yellow]API Key Required:[/bold yellow] {provider['name']}") + console.print( + f"[dim]Get your API key from:[/dim] " + f"[link={provider['api_url']}]{provider['api_url']}[/link]\n" + ) + + if existing_key: + masked = ( + f"{existing_key[:8]}...{existing_key[-4:]}" + if len(existing_key) > 12 + else "***" + ) + use_existing = Confirm.ask(f"Use existing key ({masked})?", default=True) + if use_existing: + return existing_key + + while True: + api_key = Prompt.ask("Enter your API key", password=True) + if api_key and len(api_key) > 10: # Basic validation + return api_key + console.print("[red]Invalid API key. Please try again.[/red]") + + +def run_setup_wizard(force: bool = False) -> dict[str, str]: + """Run interactive setup wizard.""" + console.print( + Panel( + "[bold cyan]🎓 yt-study Setup Wizard[/bold cyan]\n\n" + "Configure your LLM provider and API keys\n" + "[dim]Powered by LiteLLM - 400+ models supported[/dim]", + border_style="cyan", + expand=False, + ) + ) + + # Load existing config + current_config = load_config() + + if current_config and not force: + console.print("\n[yellow]Existing configuration found.[/yellow]") + reconfigure = Confirm.ask("Do you want to reconfigure?", default=False) + if not reconfigure: + console.print("[green]Using existing configuration.[/green]") + return current_config + + # Fetch available models from LiteLLM + console.print("\n[cyan]Fetching available models from LiteLLM...[/cyan]") + available_models = get_available_models() + console.print( + f"[green]✓ Found {sum(len(m) for m in available_models.values())} " + f"models across {len(available_models)} providers[/green]" + ) + + # Select provider + provider_key = select_provider(available_models) + + # Select model + model = select_model(provider_key, available_models) + + # Get API key + 
provider_info = PROVIDER_CONFIG[provider_key] + existing_key = current_config.get(provider_info["env_var"]) + api_key = get_api_key(provider_key, existing_key) + + # Optional: Configure output directory + console.print("\n[bold cyan]Output Directory:[/bold cyan]") + default_output = str(Path.cwd() / "output") + # If output dir already in config, use it as default + if "OUTPUT_DIR" in current_config: + default_output = current_config["OUTPUT_DIR"] + + output_dir = Prompt.ask("Where should notes be saved?", default=default_output) + + # Optional: Configure concurrency + console.print("\n[bold cyan]Concurrency:[/bold cyan]") + default_concurrency = current_config.get("MAX_CONCURRENT_VIDEOS", "5") + concurrency = Prompt.ask( + "Max concurrent videos to process?", default=default_concurrency + ) + + # Build config updates + new_config = { + "DEFAULT_MODEL": model, + provider_info["env_var"]: api_key, + "OUTPUT_DIR": output_dir, + "MAX_CONCURRENT_VIDEOS": concurrency, + } + + # Save configuration (merging with existing) + save_config(new_config) + + console.print("\n[bold green]✓ Setup complete![/bold green]") + console.print( + Panel( + f"[dim]Selected model:[/dim] [cyan]{model}[/cyan]\n" + f"[dim]Configuration saved to:[/dim] [cyan]{get_config_path()}[/cyan]\n\n" + "[bold]Next Steps:[/bold]\n" + 'Run: [green]yt-study process "URL"[/green]', + title="🎉 Ready to go", + border_style="green", + ) + ) + + # Return merged config + current_config.update(new_config) + return current_config diff --git a/src/yt_study/core/youtube/__init__.py b/src/yt_study/core/youtube/__init__.py new file mode 100644 index 0000000..c09344d --- /dev/null +++ b/src/yt_study/core/youtube/__init__.py @@ -0,0 +1 @@ +"""YouTube module for handling video URLs, playlists, and transcripts.""" diff --git a/src/yt_study/core/youtube/metadata.py b/src/yt_study/core/youtube/metadata.py new file mode 100644 index 0000000..14f2bca --- /dev/null +++ b/src/yt_study/core/youtube/metadata.py @@ -0,0 +1,163 @@ 
+"""Video metadata extraction using pytubefix.""" + +import logging +from dataclasses import dataclass +from typing import Any + +from pytubefix import Playlist, YouTube +from rich.console import Console + + +console = Console() +logger = logging.getLogger(__name__) + + +@dataclass +class VideoChapter: + """ + A video chapter with title and time range. + + Attributes: + title: Chapter title. + start_seconds: Start time in seconds. + end_seconds: End time in seconds (None for the last chapter). + """ + + title: str + start_seconds: int + end_seconds: int | None = None + + +def get_video_chapters(video_id: str) -> list[VideoChapter]: + """ + Get chapters from a YouTube video. + + Note: This function performs blocking network I/O. + + Args: + video_id: YouTube video ID. + + Returns: + List of VideoChapter objects, empty if no chapters found. + """ + try: + url = f"https://www.youtube.com/watch?v={video_id}" + yt = YouTube(url) + + # Access chapters if available + # pytubefix properties trigger network calls + if hasattr(yt, "chapters") and yt.chapters: + chapters: list[VideoChapter] = [] + chapter_data = yt.chapters + + for i, chapter in enumerate(chapter_data): + # Handle pytubefix chapter object structure (dict or object) + start_time = _get_attr_or_item(chapter, "start_seconds", 0) + title = _get_attr_or_item(chapter, "title", f"Chapter {i + 1}") + + # Calculate end time (start of next chapter or None for last) + end_time = None + if i < len(chapter_data) - 1: + next_chapter = chapter_data[i + 1] + end_time = _get_attr_or_item(next_chapter, "start_seconds", None) + + chapters.append( + VideoChapter( + title=str(title), + start_seconds=int(start_time), + end_seconds=int(end_time) if end_time is not None else None, + ) + ) + + return chapters + + except Exception as e: + logger.debug(f"Could not fetch chapters for {video_id}: {e}") + + return [] + + +def get_video_title(video_id: str) -> str: + """ + Get the title of a YouTube video. 
+ + Note: This function performs blocking network I/O. + + Args: + video_id: YouTube video ID. + + Returns: + Video title, or video ID if title cannot be fetched. + """ + try: + url = f"https://www.youtube.com/watch?v={video_id}" + yt = YouTube(url) + title = yt.title + + if title: + return str(title) + + except Exception as e: + logger.warning(f"Could not fetch title for {video_id}: {e}") + + # Fallback to video ID + return video_id + + +def get_video_duration(video_id: str) -> int: + """ + Get video duration in seconds. + + Note: This function performs blocking network I/O. + + Args: + video_id: YouTube video ID. + + Returns: + Duration in seconds, 0 if cannot be fetched. + """ + try: + url = f"https://www.youtube.com/watch?v={video_id}" + yt = YouTube(url) + return int(yt.length) + except Exception as e: + logger.warning(f"Could not fetch duration for {video_id}: {e}") + return 0 + + +def get_playlist_info(playlist_id: str) -> tuple[str, int]: + """ + Get playlist title and video count. + + Note: This function performs blocking network I/O. + + Args: + playlist_id: YouTube playlist ID. + + Returns: + Tuple of (title, video_count). + """ + try: + url = f"https://www.youtube.com/playlist?list={playlist_id}" + playlist = Playlist(url) + + # Pytube's title might fail if playlist is private/invalid + title = getattr(playlist, "title", f"playlist_{playlist_id}") + + # Getting length requires fetching the page + # list(playlist.video_urls) is robust but slow for huge playlists + # For metadata, it's acceptable. 
+ count = len(list(playlist.video_urls)) + + return str(title), count + + except Exception as e: + logger.warning(f"Could not fetch playlist info: {e}") + return f"playlist_{playlist_id}", 0 + + +def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any: + """Helper to get value from object attribute or dict key.""" + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) diff --git a/src/yt_study/core/youtube/parser.py b/src/yt_study/core/youtube/parser.py new file mode 100644 index 0000000..3830189 --- /dev/null +++ b/src/yt_study/core/youtube/parser.py @@ -0,0 +1,117 @@ +"""YouTube URL parser for video and playlist detection.""" + +import re +from dataclasses import dataclass +from urllib.parse import parse_qs, urlparse + + +@dataclass +class ParsedURL: + """ + Parsed YouTube URL information. + + Attributes: + url_type: Type of the URL ('video' or 'playlist'). + video_id: Extracted video ID (if present). + playlist_id: Extracted playlist ID (if present). + """ + + url_type: str # 'video' or 'playlist' + video_id: str | None = None + playlist_id: str | None = None + + +def extract_video_id(url: str) -> str | None: + """ + Extract video ID from various YouTube URL formats. + + Supports: + - Standard: https://www.youtube.com/watch?v=VIDEO_ID + - Short: https://youtu.be/VIDEO_ID + - Embed: https://www.youtube.com/embed/VIDEO_ID + - V-path: https://www.youtube.com/v/VIDEO_ID + - Shorts: https://www.youtube.com/shorts/VIDEO_ID + + Args: + url: The YouTube URL string. + + Returns: + The 11-character video ID if found, else None. 
+ """ + # Common patterns for YouTube Video IDs (11 chars, alphanumeric + _ -) + patterns = [ + r"(?:v=|\/)([0-9A-Za-z_-]{11}).*", + r"youtu\.be\/([0-9A-Za-z_-]{11})", + r"embed\/([0-9A-Za-z_-]{11})", + r"shorts\/([0-9A-Za-z_-]{11})", + ] + + for pattern in patterns: + match = re.search(pattern, url) + if match: + return match.group(1) + + return None + + +def extract_playlist_id(url: str) -> str | None: + """ + Extract playlist ID from YouTube playlist URL. + + Supports: + - https://www.youtube.com/playlist?list=PLAYLIST_ID + - https://www.youtube.com/watch?v=VIDEO_ID&list=PLAYLIST_ID + + Args: + url: The YouTube URL string. + + Returns: + The playlist ID if found, else None. + """ + try: + parsed = urlparse(url) + query_params = parse_qs(parsed.query) + + if "list" in query_params: + return query_params["list"][0] + except Exception: + # Fail gracefully on malformed URLs + pass + + return None + + +def parse_youtube_url(url: str) -> ParsedURL: + """ + Parse a YouTube URL and determine if it's a video or playlist. + + Prioritizes playlist ID if 'list' parameter is present, + but also extracts video ID if available (e.g. watching a playlist). 
+ + Args: + url: YouTube URL (video or playlist) + + Returns: + ParsedURL object with url_type and relevant IDs + + Raises: + ValueError: If URL is not a valid YouTube URL (neither video nor playlist) + """ + if not url or not isinstance(url, str): + raise ValueError("URL must be a non-empty string") + + # Check for playlist first + playlist_id = extract_playlist_id(url) + if playlist_id: + # It's a playlist URL + video_id = extract_video_id(url) # Might have both + return ParsedURL( + url_type="playlist", playlist_id=playlist_id, video_id=video_id + ) + + # Check for video + video_id = extract_video_id(url) + if video_id: + return ParsedURL(url_type="video", video_id=video_id) + + raise ValueError(f"Invalid YouTube URL: {url}") diff --git a/src/yt_study/core/youtube/playlist.py b/src/yt_study/core/youtube/playlist.py new file mode 100644 index 0000000..706867a --- /dev/null +++ b/src/yt_study/core/youtube/playlist.py @@ -0,0 +1,97 @@ +"""Playlist video extraction using pytubefix.""" + +import asyncio +import logging + +from pytubefix import Playlist +from rich.console import Console + + +console = Console() +logger = logging.getLogger(__name__) + + +class PlaylistError(Exception): + """Exception raised for playlist-related errors.""" + + pass + + +async def extract_playlist_videos(playlist_id: str) -> list[str]: + """ + Extract all video IDs from a YouTube playlist with retry logic. + + This function handles the blocking network calls of pytubefix by offloading + them to a separate thread, ensuring the asyncio event loop remains responsive. + + Args: + playlist_id: YouTube playlist ID. + + Returns: + List of video IDs. + + Raises: + PlaylistError: If playlist cannot be accessed after retries. 
+ """ + max_retries = 3 + last_error = None + + for attempt in range(max_retries): + try: + # Wrap blocking pytubefix logic in a thread + video_ids = await asyncio.to_thread(_extract_sync, playlist_id, attempt) + + if not video_ids: + # Should have been raised in _extract_sync if empty, but double check + raise ValueError( + f"No videos found in playlist (Attempt {attempt + 1}/{max_retries})" + ) + + logger.info(f"Found {len(video_ids)} videos in playlist") + return video_ids + + except Exception as e: + last_error = e + logger.warning(f"Playlist extraction attempt {attempt + 1} failed: {e}") + if attempt < max_retries - 1: + wait_time = 2**attempt # Exponential backoff + logger.warning(f"Retrying in {wait_time}s...") + await asyncio.sleep(wait_time) + + logger.error( + f"Failed to extract playlist videos after {max_retries} attempts: {last_error}" + ) + raise PlaylistError(f"Could not access playlist {playlist_id}: {str(last_error)}") + + +def _extract_sync(playlist_id: str, attempt: int) -> list[str]: + """Blocking helper to extract videos using pytubefix.""" + playlist_url = f"https://www.youtube.com/playlist?list={playlist_id}" + playlist = Playlist(playlist_url) + + # Access playlist title to trigger loading + try: + title = playlist.title + if attempt == 0: + logger.info(f"Found playlist: {title}") + except Exception: + # Title fetch might fail but video extraction might still work + logger.warning(f"Could not fetch playlist title on attempt {attempt + 1}") + + video_ids = [] + + # Extract video IDs from URLs (waits for internal generator) + # This loop triggers network requests + for url in playlist.video_urls: + if "v=" in url: + try: + # Robust ID extraction + video_id = url.split("v=")[1].split("&")[0] + video_ids.append(video_id) + except IndexError: + continue + + if not video_ids: + raise ValueError("No videos found in playlist") + + return video_ids diff --git a/src/yt_study/core/youtube/transcript.py b/src/yt_study/core/youtube/transcript.py new file 
mode 100644
index 0000000..1f8a2ac
--- /dev/null
+++ b/src/yt_study/core/youtube/transcript.py
@@ -0,0 +1,297 @@
"""Transcript fetching with multi-language support."""

import asyncio
import logging
from dataclasses import dataclass
from typing import Any

from rich.console import Console
from youtube_transcript_api import YouTubeTranscriptApi
from youtube_transcript_api._errors import (
    IpBlocked,
    NoTranscriptFound,
    RequestBlocked,
    TranscriptsDisabled,
    VideoUnavailable,
)

from .metadata import VideoChapter


console = Console()
logger = logging.getLogger(__name__)


@dataclass
class TranscriptSegment:
    """
    A segment of transcript text with timing.

    Attributes:
        text: The spoken text.
        start: Start time in seconds.
        duration: Duration of the segment in seconds.
    """

    text: str
    start: float
    duration: float


@dataclass
class VideoTranscript:
    """
    Complete transcript for a video.

    Attributes:
        video_id: The YouTube video ID.
        segments: List of transcript segments.
        language: Language name (e.g., 'English').
        language_code: Language code (e.g., 'en').
        is_generated: Whether the transcript is auto-generated.
    """

    video_id: str
    segments: list[TranscriptSegment]
    language: str
    language_code: str
    is_generated: bool

    def to_text(self) -> str:
        """Convert transcript segments to continuous text."""
        return " ".join(segment.text for segment in self.segments)


class TranscriptError(Exception):
    """Exception raised for transcript-related errors."""

    pass


class YouTubeIPBlockError(TranscriptError):
    """Exception raised when YouTube blocks IP."""

    pass


async def fetch_transcript(
    video_id: str, languages: list[str] | None = None
) -> VideoTranscript:
    """
    Fetch transcript for a YouTube video with language fallback and retry logic.

    Priority:
    1. Manual transcript in preferred language
    2. Auto-generated transcript in preferred language
    3. Manual transcript in any available language
    4. Auto-generated transcript in any available language
    5. Translated transcript to English

    Args:
        video_id: YouTube video ID.
        languages: Preferred language codes (e.g., ['en', 'hi']). Defaults to ['en'].

    Returns:
        VideoTranscript object.

    Raises:
        TranscriptError: If no transcript is available.
        YouTubeIPBlockError: If YouTube blocks requests from this IP.
    """
    if languages is None:
        languages = ["en"]

    # Only transient errors are retried; fatal/classified errors raise at once.
    retries = 3

    for attempt in range(retries):
        try:
            # Wrap blocking YouTubeTranscriptApi calls in a thread
            # This is critical to prevent blocking the asyncio event loop
            # during concurrency
            raw_transcript, transcript_meta, log_msg = await asyncio.to_thread(
                _fetch_sync, video_id, languages
            )

            logger.info(log_msg)

            # Convert to our format
            segments: list[TranscriptSegment] = []
            for segment in raw_transcript:
                # Handle both dict (standard) and object
                # (FetchedTranscriptSnippet) formats
                if isinstance(segment, dict):
                    text = segment.get("text", "")
                    start = segment.get("start", 0.0)
                    duration = segment.get("duration", 0.0)
                else:
                    # Fallback for object-based returns
                    text = getattr(segment, "text", "")
                    start = getattr(segment, "start", 0.0)
                    duration = getattr(segment, "duration", 0.0)

                segments.append(
                    TranscriptSegment(
                        text=text, start=float(start), duration=float(duration)
                    )
                )

            return VideoTranscript(
                video_id=video_id,
                segments=segments,
                language=transcript_meta.language,
                language_code=transcript_meta.language_code,
                is_generated=transcript_meta.is_generated,
            )

        # NOTE: the order of these except clauses matters — the specific,
        # non-retryable classifications must run before the generic handler.
        except (TranscriptsDisabled, VideoUnavailable) as e:
            # Fatal errors, do not retry
            logger.error(f"Transcript unavailable for {video_id}: {e}")
            raise TranscriptError(
                f"Transcripts are disabled or video is unavailable: {video_id}"
            ) from e

        except (TranscriptError, NoTranscriptFound):
            # Already handled or strictly not found, do not retry
            raise

        except (IpBlocked, RequestBlocked) as e:
            # Specifically handle IP blocking
            logger.error(f"YouTube IP Block detected for {video_id}")
            raise YouTubeIPBlockError(
                "YouTube is blocking requests from your IP. "
                "Please try using a VPN, proxies, or wait a while."
            ) from e

        except Exception as e:
            # Some library versions surface IP blocks as generic errors whose
            # message mentions blocking — classify those by string match.
            err_str = str(e)
            if "blocking requests from your IP" in err_str:
                logger.error(f"YouTube IP Block detected for {video_id}: {e}")
                raise YouTubeIPBlockError(
                    "YouTube is blocking requests from your IP. "
                    "Please try using a VPN, proxies, or wait a while."
                ) from e

            if attempt < retries - 1:
                # Exponential backoff: 1s, 2s, ...
                wait_time = 2**attempt
                logger.warning(
                    f"Transcript fetch failed ({str(e)}), retrying in {wait_time}s..."
                )
                await asyncio.sleep(wait_time)
            else:
                logger.error(f"Failed to fetch transcript for {video_id}: {e}")
                raise TranscriptError(f"Could not fetch transcript: {str(e)}") from e

    # Should be unreachable due to raise in loop
    raise TranscriptError(f"Failed to fetch transcript for {video_id}")


def _fetch_sync(video_id: str, languages: list[str]) -> tuple[Any, Any, str]:
    """Blocking helper to interact with YouTubeTranscriptApi."""
    ytt_api = YouTubeTranscriptApi()

    # List all available transcripts
    # This list call can fail with TranscriptsDisabled or VideoUnavailable
    transcript_list = ytt_api.list(video_id)

    transcript = None
    found_msg = ""

    # Strategy 1: Find manual transcript in preferred language
    try:
        transcript = transcript_list.find_manually_created_transcript(languages)
        found_msg = f"Found manual transcript: {transcript.language}"
    except NoTranscriptFound:
        pass

    # Strategy 2: Try auto-generated in preferred language
    if not transcript:
        try:
            transcript = transcript_list.find_generated_transcript(languages)
            found_msg = f"Using auto-generated transcript: {transcript.language}"
        except NoTranscriptFound:
            pass

    # Strategy 3: Try any manual transcript
    if not transcript:
        try:
            # Get all language codes available
            all_codes = [t.language_code for t in transcript_list]
            transcript = transcript_list.find_manually_created_transcript(all_codes)
            found_msg = f"Using manual transcript in {transcript.language}"
        except NoTranscriptFound:
            pass

    # Strategy 4: Last resort - try any available transcript and translate if needed
    if not transcript:
        try:
            # list(transcript_list) returns iterable of Transcript objects
            available = list(transcript_list)
            if not available:
                raise NoTranscriptFound(video_id, languages, [])

            first_available = available[0]

            # Try to translate to English if not English already and requested
            if "en" in languages and first_available.language_code != "en":
                if first_available.is_translatable:
                    transcript = first_available.translate("en")
                    found_msg = f"Translated {first_available.language} -> English"
                else:
                    transcript = first_available
                    found_msg = (
                        f"Using {transcript.language} (translation not available)"
                    )
            else:
                transcript = first_available
                found_msg = f"Using {transcript.language}"

        except Exception as e:
            # If we really can't find anything
            if isinstance(e, NoTranscriptFound):
                raise
            raise TranscriptError(f"No usable transcript found: {e}") from e

    # Fetch the actual transcript data
    raw_transcript = transcript.fetch()
    return raw_transcript, transcript, found_msg


def split_transcript_by_chapters(
    transcript: VideoTranscript, chapters: list[VideoChapter]
) -> dict[str, str]:
    """
    Split a video transcript by chapters.

    Segments are assigned to a chapter by their start time; a segment that
    begins inside a chapter belongs to it even if it runs past the boundary.

    Args:
        transcript: VideoTranscript object.
        chapters: List of VideoChapter objects.

    Returns:
        Dictionary mapping chapter titles to their transcript text.
    """
    chapter_transcripts: dict[str, str] = {}

    for chapter in chapters:
        # Filter segments for this chapter
        chapter_segments: list[str] = []

        for segment in transcript.segments:
            segment_start = segment.start

            # Check if segment start is within chapter range
            if chapter.end_seconds is None:
                # Last chapter - include everything after start
                if segment_start >= chapter.start_seconds:
                    chapter_segments.append(segment.text)
            else:
                # Middle chapters - include if in range
                if (
                    segment_start >= chapter.start_seconds
                    and segment_start < chapter.end_seconds
                ):
                    chapter_segments.append(segment.text)

        # Combine segments for this chapter
        chapter_text = " ".join(chapter_segments)
        chapter_transcripts[chapter.title] = chapter_text

    return chapter_transcripts
diff --git a/src/yt_study/llm/__init__.py b/src/yt_study/llm/__init__.py
index 2d5e18e..258ecae 100644
--- a/src/yt_study/llm/__init__.py
+++ b/src/yt_study/llm/__init__.py
@@ -1 +1,3 @@
-"""LLM module for multi-provider support and content generation."""
+"""Backward compatibility - llm package moved to core.llm."""
+
+from yt_study.core.llm import * # noqa: F401, F403
diff --git a/src/yt_study/llm/generator.py b/src/yt_study/llm/generator.py
index 0dfee2c..37e9917 100644
--- a/src/yt_study/llm/generator.py
+++ b/src/yt_study/llm/generator.py
@@ -1,355 +1,3 @@
-"""Study material generator with chunking and combining logic."""
+"""Backward compatibility - llm.generator moved to core.llm.generator."""
-import logging
-
-from litellm import token_counter
-from rich.console import Console
-from rich.progress import Progress, TaskID
-
-from ..config import config
-from ..prompts.chapter_notes import (
-    get_chapter_prompt,
-    get_combine_chapters_prompt,
-)
-from ..prompts.study_notes import (
-    SYSTEM_PROMPT,
-    get_chunk_prompt,
-    get_combine_prompt,
-    get_single_pass_prompt,
-)
-from .providers import LLMProvider
-
-
-# Re-use system prompt for now
-CHAPTER_SYSTEM_PROMPT = SYSTEM_PROMPT
-
-console = Console()
-logger = logging.getLogger(__name__) - - -class StudyMaterialGenerator: - """ - Generate study materials from transcripts using LLM. - - Handles token counting, text chunking, and recursive summarization/generation. - """ - - def __init__( - self, - provider: LLMProvider, - temperature: float = 0.7, - max_tokens: int | None = None, - ): - """ - Initialize generator. - - Args: - provider: LLM provider instance. - temperature: LLM response temperature. - max_tokens: Maximum tokens for LLM responses. - """ - self.provider = provider - self.temperature = temperature - self.max_tokens = max_tokens - - def _count_tokens(self, text: str) -> int: - """Count tokens in text using model-specific tokenizer.""" - # Note: token_counter might do network calls for some models or use - # local libraries (tiktoken). For efficiency, we assume it's fast. - try: - count = token_counter(model=self.provider.model, text=text) - return int(count) if count is not None else len(text) // 4 - except Exception: - # Fallback estimation if tokenizer fails (approx 4 chars per token) - return len(text) // 4 - - def _chunk_transcript(self, transcript: str) -> list[str]: - """ - Split transcript into chunks with overlap. - - Uses recursive chunking strategy: - - Target size: Defined in config (default 4000 tokens) - - Overlap: Defined in config (default 200 tokens) - - Priority: Sentence boundaries > Newlines > Words > Hard char limit - - Args: - transcript: The full transcript text. - - Returns: - List of text chunks. - """ - token_count = self._count_tokens(transcript) - - # Fast path: Return single chunk if within limits - if token_count <= config.chunk_size: - return [transcript] - - logger.info( - f"Transcript too long ({token_count:,} tokens), performing chunking..." - ) - - chunks: list[str] = [] - - # Strategy 1: Split by sentences - sentences = transcript.split(". 
") - - # Strategy 2: Split by newlines if sentences fail - if len(sentences) < 2 and token_count > config.chunk_size: - sentences = transcript.split("\n") - - # Strategy 3: Split by spaces if newlines fail - if len(sentences) < 2: - sentences = transcript.split(" ") - - current_chunk: list[str] = [] - current_tokens = 0 - - for sentence in sentences: - sentence = sentence.strip() - if not sentence: - continue - - # Re-add delimiter for estimation (approximate) - # We assume '. ' was the delimiter for simplicity, logic holds - # for others mostly as we care about token count - term = sentence + ". " - term_tokens = self._count_tokens(term) - - # Handle edge case: Single sentence/segment is larger than chunk_size - if term_tokens > config.chunk_size: - # 1. Flush current buffer - if current_chunk: - chunks.append(" ".join(current_chunk)) - current_chunk = [] - current_tokens = 0 - - # 2. Hard split the massive segment - # Estimate char limit based on token size (conservative 3 chars/token) - char_limit = config.chunk_size * 3 - for i in range(0, len(sentence), char_limit): - sub_part = sentence[i : i + char_limit] - chunks.append(sub_part) - continue - - # Standard accumulation - if current_tokens + term_tokens > config.chunk_size: - # Chunk is full. Commit it. 
- if current_chunk: - chunks.append(" ".join(current_chunk)) - - # Create overlap for next chunk - overlap_chunk: list[str] = [] - overlap_tokens = 0 - - # Take sentences from the end of current_chunk until overlap limit - for prev_sent in reversed(current_chunk): - prev_tokens = self._count_tokens(prev_sent) - if overlap_tokens + prev_tokens <= config.chunk_overlap: - overlap_chunk.insert(0, prev_sent) - overlap_tokens += prev_tokens - else: - break - - current_chunk = overlap_chunk + [sentence] - current_tokens = self._count_tokens(" ".join(current_chunk)) - else: - # Should be unreachable due to check above, but safe fallback - current_chunk.append(sentence) - current_tokens += term_tokens - else: - current_chunk.append(sentence) - current_tokens += term_tokens - - # Add remaining chunk - if current_chunk: - chunks.append(" ".join(current_chunk)) - - logger.info(f"Created {len(chunks)} chunks") - return chunks - - def _update_status( - self, - progress: Progress | None, - task_id: TaskID | None, - video_title: str, - message: str, - ) -> None: - """Safe helper to update progress bar or log message.""" - if progress and task_id is not None: - short_title = ( - (video_title[:20] + "...") if len(video_title) > 20 else video_title - ) - # We assume the layout uses 'description' for the status text - progress.update( - task_id, description=f"[yellow]{short_title}[/yellow]: {message}" - ) - else: - logger.info(f"{video_title}: {message}") - - async def generate_study_notes( - self, - transcript: str, - video_title: str = "Video", - progress: Progress | None = None, - task_id: TaskID | None = None, - ) -> str: - """ - Generate study notes from transcript. - - Args: - transcript: Full video transcript text. - video_title: Video title for progress display. - progress: Optional existing progress bar instance. - task_id: Optional task ID for updating progress. - - Returns: - Complete study notes in Markdown format. 
- """ - chunks = self._chunk_transcript(transcript) - - # Single chunk - generate directly - if len(chunks) == 1: - self._update_status(progress, task_id, video_title, "Generating notes...") - - notes = await self.provider.generate( - system_prompt=SYSTEM_PROMPT, - user_prompt=get_single_pass_prompt(transcript), - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - - if not progress: - logger.info(f"Generated notes for {video_title}") - return notes - - # Multiple chunks - generate for each, then combine - self._update_status( - progress, - task_id, - video_title, - f"Generating notes for {len(chunks)} chunks...", - ) - - chunk_notes = [] - - for i, chunk in enumerate(chunks, 1): - msg = f"Chunk {i}/{len(chunks)} (Generating)" - self._update_status(progress, task_id, video_title, msg) - - note = await self.provider.generate( - system_prompt=SYSTEM_PROMPT, - user_prompt=get_chunk_prompt(chunk), - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - chunk_notes.append(note) - - self._update_status( - progress, - task_id, - video_title, - f"Combining {len(chunk_notes)} chunk notes...", - ) - - final_notes = await self.provider.generate( - system_prompt=SYSTEM_PROMPT, - user_prompt=get_combine_prompt(chunk_notes), - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - - if not progress: - logger.info(f"Completed notes for {video_title}") - - return final_notes - - async def generate_single_chapter_notes( - self, - chapter_title: str, - chapter_text: str, - ) -> str: - """ - Generate study notes for a single chapter. - - Args: - chapter_title: Title of the chapter. - chapter_text: Transcript text for the chapter. - - Returns: - Study notes for the chapter. 
- """ - notes = await self.provider.generate( - system_prompt=CHAPTER_SYSTEM_PROMPT, - user_prompt=get_chapter_prompt(chapter_title, chapter_text), - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - return notes - - async def generate_chapter_based_notes( - self, - chapter_transcripts: dict[str, str], - video_title: str = "Video", - progress: Progress | None = None, - task_id: TaskID | None = None, - ) -> str: - """ - Generate study notes using chapter-based approach. - - Args: - chapter_transcripts: Dictionary mapping chapter titles to transcript text. - video_title: Video title for display. - progress: Optional existing progress bar instance. - task_id: Optional task ID for updating progress. - - Returns: - Complete study notes organized by chapters. - """ - # Imports are already at top-level or can be moved up, but let's - # fix the specific issue. Previously we did lazy import inside - # function which caused issues - - self._update_status( - progress, - task_id, - video_title, - f"Generating notes for {len(chapter_transcripts)} chapters...", - ) - - chapter_notes = {} - total_chapters = len(chapter_transcripts) - - for i, (chapter_title, chapter_text) in enumerate( - chapter_transcripts.items(), 1 - ): - msg = f"Chapter {i}/{total_chapters}: {chapter_title[:20]}..." - self._update_status(progress, task_id, video_title, msg) - - # If a chapter is huge, we might need recursive chunking here too. - # For now, we assume chapters are reasonably sized or the model - # can handle ~100k context. Future improvement: Check token - # count of chapter_text and recurse if needed. - - notes = await self.provider.generate( - system_prompt=CHAPTER_SYSTEM_PROMPT, - user_prompt=get_chapter_prompt(chapter_title, chapter_text), - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - chapter_notes[chapter_title] = notes - - self._update_status( - progress, task_id, video_title, "Combining chapter notes..." 
- ) - - final_notes = await self.provider.generate( - system_prompt=CHAPTER_SYSTEM_PROMPT, - user_prompt=get_combine_chapters_prompt(chapter_notes), - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - - if not progress: - logger.info(f"Completed chapter-based notes for {video_title}") - - return final_notes +from yt_study.core.llm.generator import * # noqa: F401, F403 diff --git a/src/yt_study/llm/providers.py b/src/yt_study/llm/providers.py index 625e45d..0e1eeb2 100644 --- a/src/yt_study/llm/providers.py +++ b/src/yt_study/llm/providers.py @@ -1,152 +1,3 @@ -"""LLM provider configuration using LiteLLM.""" +"""Backward compatibility - llm.providers moved to core.llm.providers.""" -import logging -import os -from typing import Any - -from litellm import acompletion -from rich.console import Console - -from ..config import config - - -console = Console() -logger = logging.getLogger(__name__) - - -class LLMGenerationError(Exception): - """Exception raised when LLM generation fails.""" - - pass - - -class LLMProvider: - """ - LLM provider interface using LiteLLM. - - Handles API key verification and text generation with retries. - """ - - def __init__(self, model: str = "gemini/gemini-2.0-flash"): - """ - Initialize LLM provider. - - Args: - model: LiteLLM-compatible model string (e.g., 'gemini/gemini-2.0-flash'). - """ - self.model = model - self._validate_config() - - def _validate_config(self) -> None: - """ - Verify that the necessary API key for the selected model is set. - Logs a warning if missing. - """ - # We rely on Config to check environment variables, - # but we can double check here for the specific model - key_name = config.get_api_key_name_for_model(self.model) - if key_name: - if not os.getenv(key_name): - logger.warning( - f"API Key for model '{self.model}' ({key_name}) not found " - "in environment. Generation may fail." 
- ) - else: - # If we can't map the model to a specific key (unknown provider), - # we assume the user knows what they are doing or it doesn't need - # one (e.g. ollama) - logger.debug(f"No specific API key mapping found for model: {self.model}") - - async def generate( - self, - system_prompt: str, - user_prompt: str, - temperature: float = 0.7, - max_tokens: int | None = None, - ) -> str: - """ - Generate text using the configured LLM. - - Args: - system_prompt: System/instruction prompt. - user_prompt: User query/content. - temperature: Sampling temperature (0.0 to 1.0). - max_tokens: Maximum tokens to generate (None for model default). - - Returns: - Generated text content. - - Raises: - LLMGenerationError: If generation fails after retries. - """ - try: - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ] - - kwargs: dict[str, Any] = { - "model": self.model, - "messages": messages, - "temperature": temperature, - # LiteLLM handles exponential backoff for RateLimitError - "num_retries": 3, - } - - if max_tokens: - kwargs["max_tokens"] = max_tokens - - # LiteLLM's acompletion handles async requests to various providers - response = await acompletion(**kwargs) - - # safely extract content - if not response.choices or not response.choices[0].message.content: - raise LLMGenerationError("Received empty response from LLM provider") - - content = response.choices[0].message.content.strip() - return self._clean_content(content) - - except Exception as e: - logger.error(f"LLM generation failed with {self.model}: {e}", exc_info=True) - raise LLMGenerationError( - f"Failed to generate with {self.model}: {str(e)}" - ) from e - - def _clean_content(self, content: str) -> str: - """ - Remove markdown code block fencing if the LLM wraps the entire output in it. - - Args: - content: Raw LLM output. - - Returns: - Cleaned content string. 
- """ - # Check for triple backticks - if content.startswith("```"): - lines = content.splitlines() - # Need at least fence start, content, fence end - if len(lines) >= 2 and lines[0].strip().startswith("```"): - # If the first line is just a fence (with optional language), remove it - # Check if the last line is also a fence - if lines[-1].strip() == "```": - return "\n".join(lines[1:-1]).strip() - # Sometimes LLMs stop abruptly or formatting is weird; - # if it starts with fence, we strip the first line. - # If it ends with fence, strip that too. - return "\n".join(lines[1:]).strip().removesuffix("```").strip() - - return content - - -def get_provider(model: str = "gemini/gemini-2.0-flash") -> LLMProvider: - """ - Factory function to get an LLM provider instance. - - Args: - model: LiteLLM-compatible model string. - - Returns: - Configured LLMProvider instance. - """ - return LLMProvider(model=model) +from yt_study.core.llm.providers import * # noqa: F401, F403 diff --git a/src/yt_study/pipeline/__init__.py b/src/yt_study/pipeline/__init__.py index d4edac6..c3b37b9 100644 --- a/src/yt_study/pipeline/__init__.py +++ b/src/yt_study/pipeline/__init__.py @@ -1,6 +1,5 @@ -"""Pipeline orchestration module.""" +"""Backward compatibility - pipeline package moved to core.""" -from .orchestrator import PipelineOrchestrator +from yt_study.core.orchestrator import PipelineOrchestrator, sanitize_filename - -__all__ = ["PipelineOrchestrator"] +__all__ = ["PipelineOrchestrator", "sanitize_filename"] diff --git a/src/yt_study/pipeline/orchestrator.py b/src/yt_study/pipeline/orchestrator.py index c183699..7128c6b 100644 --- a/src/yt_study/pipeline/orchestrator.py +++ b/src/yt_study/pipeline/orchestrator.py @@ -1,534 +1,4 @@ -"""Main pipeline orchestrator with concurrent processing.""" +"""Backward compatibility - pipeline.orchestrator moved to core.orchestrator.""" -import asyncio -import logging -import re -from pathlib import Path - -from rich.console import Console -from 
rich.live import Live -from rich.panel import Panel -from rich.progress import Progress, TaskID -from rich.table import Table - -from ..config import config -from ..llm.generator import StudyMaterialGenerator -from ..llm.providers import get_provider -from ..ui.dashboard import PipelineDashboard -from ..youtube.metadata import ( - get_playlist_info, - get_video_chapters, - get_video_duration, - get_video_title, -) -from ..youtube.parser import parse_youtube_url -from ..youtube.playlist import extract_playlist_videos -from ..youtube.transcript import ( - YouTubeIPBlockError, - fetch_transcript, - split_transcript_by_chapters, -) - - -console = Console() -logger = logging.getLogger(__name__) - - -def sanitize_filename(name: str) -> str: - """ - Sanitize a string to be used as a filename. - - Args: - name: Raw filename string. - - Returns: - Sanitized string safe for file systems. - """ - # Remove or replace invalid characters - name = re.sub(r'[<>:"/\\|?*]', "", name) - # Replace multiple spaces with single space - name = re.sub(r"\s+", " ", name) - # Trim and limit length - name = name.strip()[:100] - return name if name else "untitled" - - -class PipelineOrchestrator: - """ - Orchestrates the end-to-end pipeline for video processing. - - Manages concurrency, error handling, and UI updates. - """ - - def __init__( - self, - model: str = "gemini/gemini-2.0-flash", - output_dir: Path | None = None, - languages: list[str] | None = None, - temperature: float | None = None, - max_tokens: int | None = None, - ): - """ - Initialize orchestrator. - - Args: - model: LLM model string. - output_dir: Output directory path. - languages: Preferred transcript languages. - temperature: LLM temperature (defaults to config.temperature). - max_tokens: Max tokens (defaults to config.max_tokens). 
- """ - self.model = model - self.output_dir = output_dir or config.default_output_dir - self.languages = languages or config.default_languages - self.temperature = ( - temperature if temperature is not None else config.temperature - ) - self.max_tokens = max_tokens if max_tokens is not None else config.max_tokens - self.provider = get_provider(model) - self.generator = StudyMaterialGenerator( - self.provider, - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - self.semaphore = asyncio.Semaphore(config.max_concurrent_videos) - - def validate_provider(self) -> bool: - """ - Validate that the API key for the selected provider is set. - - Returns: - True if valid (or warning logged), False if critical missing config. - """ - key_name = config.get_api_key_name_for_model(self.model) - - if key_name: - import os - - if not os.environ.get(key_name): - console.print( - f"\n[red bold]✗ Missing API Key for {self.model}[/red bold]" - ) - console.print( - f"[yellow]Expected environment variable: {key_name}[/yellow]" - ) - console.print( - "[dim]Please check your .env file or run:[/dim] " - "[cyan]yt-study setup[/cyan]\n" - ) - return False - - return True - - async def process_video( - self, - video_id: str, - output_path: Path, - progress: Progress | None = None, - task_id: TaskID | None = None, - video_title: str | None = None, - is_playlist: bool = False, - ) -> bool: - """ - Process a single video: fetch transcript and generate study notes. - - Args: - video_id: YouTube Video ID. - output_path: Destination path for the MD file. - progress: Rich Progress instance. - task_id: Rich TaskID. - video_title: Pre-fetched title (optional). - is_playlist: Whether this is part of a playlist (affects UI logging). - - Returns: - True on success, False on failure. 
- """ - async with self.semaphore: - local_task_id = task_id - - # If standalone (not part of worker pool), create a specific - # bar if requested - if is_playlist and progress and task_id is None: - display_title = (video_title or video_id)[:30] - local_task_id = progress.add_task( - description=f"[cyan]⏳ {display_title}... (Waiting)[/cyan]", - total=None, - ) - - try: - # 1. Fetch Metadata - if not video_title: - # Run in thread to avoid blocking - video_title = await asyncio.to_thread(get_video_title, video_id) - - # Fetch duration and chapters concurrently - duration, chapters = await asyncio.gather( - asyncio.to_thread(get_video_duration, video_id), - asyncio.to_thread(get_video_chapters, video_id), - ) - - title_display = (video_title or video_id)[:40] - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=f"[cyan]📥 {title_display}... (Transcript)[/cyan]", - ) - - # 2. Fetch Transcript - transcript_obj = await fetch_transcript(video_id, self.languages) - - # 3. Determine Generation Strategy - # Use chapters if video is long (>1h) and chapters exist - use_chapters = duration > 3600 and len(chapters) > 0 and not is_playlist - - if use_chapters: - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[cyan]📖 {title_display}... 
(Chapters)[/cyan]" - ), - ) - # else block removed as redundant - - # Split transcript - chapter_transcripts = split_transcript_by_chapters( - transcript_obj, chapters - ) - - # Create folder for chapter notes - safe_title = sanitize_filename(video_title) - output_folder = self.output_dir / safe_title - output_folder.mkdir(parents=True, exist_ok=True) - - # Generate chapter notes - # Fix: Iterate here and call generator for each chapter - # to save individually - - for i, (chap_title, chap_text) in enumerate( - chapter_transcripts.items(), 1 - ): - status_msg = f"Chapter {i}/{len(chapter_transcripts)}" - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[cyan]🤖 {title_display}... ({status_msg})[/cyan]" - ), - ) - - notes = await self.generator.generate_single_chapter_notes( - chapter_title=chap_title, - chapter_text=chap_text, - ) - - # Save individual chapter - safe_chapter = sanitize_filename(chap_title) - chapter_file = output_folder / f"{i:02d}_{safe_chapter}.md" - chapter_file.write_text(notes, encoding="utf-8") - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=f"[green]✓ {title_display} (Done)[/green]", - completed=True, - ) - - return True - - else: - # Single file generation - transcript_text = transcript_obj.to_text() - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[cyan]🤖 {title_display}... 
(Generating)[/cyan]" - ), - ) - - notes = await self.generator.generate_study_notes( - transcript_text, - video_title=title_display, - progress=progress, - task_id=local_task_id, - ) - - output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(notes, encoding="utf-8") - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=f"[green]✓ {title_display} (Done)[/green]", - completed=True, - ) - - return True - - except Exception as e: - logger.error(f"Failed to process {video_id}: {e}") - - err_msg = str(e) - if isinstance(e, YouTubeIPBlockError) or ( - "blocking requests" in err_msg - ): - err_display = "[bold red]IP BLOCKED[/bold red]" - console.print( - Panel( - "[bold red]🚫 YouTube IP Block Detected[/bold red]\n\n" - "YouTube is limiting requests from your IP address.\n" - "[yellow]➤ Recommendation:[/yellow] Use a VPN or " - "wait ~1 hour.", - border_style="red", - ) - ) - else: - err_display = "(Failed)" - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[red]✗ {(video_title or video_id)[:20]}... " - f"{err_display}[/red]" - ), - visible=True, - ) - - return False - - async def _process_with_dashboard( - self, - video_ids: list[str], - playlist_name: str = "Queue", - is_single_video: bool = False, - ) -> int: - """Process a list of videos using the Advanced Dashboard UI.""" - from ..ui.dashboard import PipelineDashboard - - # Initialize Dashboard FIRST to capture all output - # Adjust concurrency display: if total_videos < max_concurrency, - # only show needed workers - actual_concurrency = min(len(video_ids), config.max_concurrent_videos) - - dashboard = PipelineDashboard( - total_videos=len(video_ids), - concurrency=actual_concurrency, - playlist_name=playlist_name, - model_name=self.model, - ) - - success_count = 0 - video_titles = {} - - # Run Live Display (inline, not full screen) - # We start it immediately to show "Fetching metadata..." 
state - with Live(dashboard, refresh_per_second=10, console=console, screen=False): - # --- Phase 1: Metadata Fetching --- - TITLE_FETCH_CONCURRENCY = 10 - if not is_single_video: - dashboard.update_overall_status( - f"[cyan]📋 Fetching metadata for {len(video_ids)} videos...[/cyan]" - ) - - title_semaphore = asyncio.Semaphore(TITLE_FETCH_CONCURRENCY) - - async def fetch_title_safe(vid: str) -> str: - async with title_semaphore: - try: - return await asyncio.to_thread(get_video_title, vid) - except Exception: - return vid - - # Fetch titles - titles = await asyncio.gather(*(fetch_title_safe(vid) for vid in video_ids)) - video_titles = dict(zip(video_ids, titles, strict=True)) - - # --- Phase 2: Processing --- - if not is_single_video: - dashboard.update_overall_status("[bold blue]Total Progress[/bold blue]") - - # Determine base output folder - if is_single_video: - base_folder = self.output_dir - else: - base_folder = self.output_dir / sanitize_filename(playlist_name) - base_folder.mkdir(parents=True, exist_ok=True) - - # Worker Queue Implementation - queue: asyncio.Queue[str] = asyncio.Queue() - for vid in video_ids: - queue.put_nowait(vid) - - async def worker(worker_idx: int, task_id: TaskID) -> None: - nonlocal success_count - while not queue.empty(): - try: - video_id = queue.get_nowait() - except asyncio.QueueEmpty: - break - - title = video_titles.get(video_id, video_id) - safe_title = sanitize_filename(title) - - if is_single_video: - video_folder = base_folder / safe_title - output_path = video_folder / f"{safe_title}.md" - else: - output_path = base_folder / f"{safe_title}.md" - - # Update status - dashboard.update_worker( - worker_idx, f"[yellow]{title[:30]}...[/yellow]" - ) - - try: - result = await self.process_video( - video_id, - output_path, - progress=dashboard.worker_progress, - task_id=task_id, - video_title=title, - is_playlist=not is_single_video, - ) - - if result: - success_count += 1 - dashboard.add_completion(title) - else: - 
dashboard.add_failure(title) - - except Exception as e: - logger.error(f"Worker {worker_idx} failed on {video_id}: {e}") - dashboard.update_worker(worker_idx, f"[red]Error: {e}[/red]") - dashboard.add_failure(title) - await asyncio.sleep(2) - finally: - queue.task_done() - - # Worker done - dashboard.update_worker(worker_idx, "[dim]Idle[/dim]") - - try: - workers = [ - asyncio.create_task(worker(i, dashboard.worker_tasks[i])) - for i in range(actual_concurrency) - ] - await asyncio.gather(*workers) - except Exception as e: - logger.error(f"Dashboard execution failed: {e}") - - # Print summary table after dashboard closes - self._print_summary(dashboard) - - return success_count - - def _print_summary(self, dashboard: "PipelineDashboard") -> None: - """Print a summary table of the run.""" - if not dashboard.recent_completions and not dashboard.recent_failures: - return - - summary_table = Table( - title="📊 Processing Summary", - border_style="cyan", - show_header=True, - header_style="bold magenta", - ) - summary_table.add_column("Status", justify="center") - summary_table.add_column("Video Title", style="dim") - - # Add failures first (more important) - if dashboard.recent_failures: - for fail in dashboard.recent_failures: - summary_table.add_row("[bold red]FAILED[/bold red]", fail) - - # Add successes - if dashboard.recent_completions: - for comp in dashboard.recent_completions: - summary_table.add_row("[green]SUCCESS[/green]", comp) - - console.print("\n") - console.print(summary_table) - console.print( - f"\n[bold]Total Completed:[/bold] " - f"{dashboard.overall_progress.tasks[0].completed}/" - f"{dashboard.overall_progress.tasks[0].total}" - ) - console.print("[dim]Check logs for detailed error reports.[/dim]\n") - - async def process_playlist( - self, playlist_id: str, playlist_name: str = "playlist" - ) -> int: - """Process playlist with concurrent dynamic progress bars.""" - video_ids = await extract_playlist_videos(playlist_id) - return await 
self._process_with_dashboard(video_ids, playlist_name) - - async def run(self, url: str) -> None: - """ - Run the pipeline for a given YouTube URL. - - Args: - url: YouTube video or playlist URL. - """ - # Validate Provider Credentials - if not self.validate_provider(): - return - - try: - # Parse URL - parsed = parse_youtube_url(url) - - if parsed.url_type == "video": - if not parsed.video_id: - console.print("[red]Error: Video ID could not be extracted[/red]") - return - - await self._process_with_dashboard( - [parsed.video_id], - playlist_name="Single Video", - is_single_video=True, - ) - - # Summary is already printed by _process_with_dashboard - - elif parsed.url_type == "playlist": - if not parsed.playlist_id: - console.print( - "[red]Error: Playlist ID could not be extracted[/red]" - ) - return - - # Fetch basic playlist info first - handled in dashboard now - # if needed or kept minimal. Actually, playlist title fetching - # is useful to show BEFORE starting but _process_with_dashboard - # fetches metadata anyway. - # However, to pass playlist_name to dashboard, we might want it. - # But waiting for title can be slow. - # Let's let the dashboard handle titles for videos. - # For playlist title, we can try fast fetch or default to ID. - - # Fetching playlist title here is blocking/slow if not careful. - # Let's just use ID as name initially or fetch it quickly. - # The original code did fetch it. - - # To reduce redundancy, we remove the print statement - # "Playlist: ..." 
- playlist_title, _ = await asyncio.to_thread( - get_playlist_info, parsed.playlist_id - ) - - # Removed redundant print: - # console.print(f"[cyan]📑 Playlist:[/cyan] {playlist_title}\n") - - await self.process_playlist(parsed.playlist_id, playlist_title) - - # Summary handled by dashboard - - except ValueError as e: - console.print(f"[red]Input Error: {e}[/red]") - except Exception as e: - console.print(f"[red]Unexpected Error: {e}[/red]") - logger.exception("Pipeline run failed") +# Re-export everything from the new location +from yt_study.core.orchestrator import * # noqa: F401, F403 diff --git a/src/yt_study/setup_wizard.py b/src/yt_study/setup_wizard.py index cb52429..823d111 100644 --- a/src/yt_study/setup_wizard.py +++ b/src/yt_study/setup_wizard.py @@ -1,420 +1,4 @@ -"""Configuration wizard for yt-study.""" +"""Backward compatibility - setup_wizard moved to core.setup_wizard.""" -from pathlib import Path -from typing import Any - -from rich.console import Console -from rich.panel import Panel -from rich.prompt import Confirm, Prompt -from rich.table import Table - - -console = Console() - - -# API key configuration for different providers -PROVIDER_CONFIG: dict[str, dict[str, Any]] = { - "gemini": { - "name": "Google Gemini", - "env_var": "GEMINI_API_KEY", - "api_url": "https://aistudio.google.com/app/apikey", - "keywords": ["gemini", "vertex"], - }, - "openai": { - "name": "OpenAI (ChatGPT)", - "env_var": "OPENAI_API_KEY", - "api_url": "https://platform.openai.com/api-keys", - "keywords": ["gpt", "openai", "o1"], - }, - "anthropic": { - "name": "Anthropic (Claude)", - "env_var": "ANTHROPIC_API_KEY", - "api_url": "https://console.anthropic.com/settings/keys", - "keywords": ["claude", "anthropic"], - }, - "groq": { - "name": "Groq", - "env_var": "GROQ_API_KEY", - "api_url": "https://console.groq.com/keys", - "keywords": ["groq"], - }, - "xai": { - "name": "xAI (Grok)", - "env_var": "XAI_API_KEY", - "api_url": "https://console.x.ai/", - "keywords": ["grok", 
"xai"], - }, - "mistral": { - "name": "Mistral AI", - "env_var": "MISTRAL_API_KEY", - "api_url": "https://console.mistral.ai/api-keys/", - "keywords": ["mistral"], - }, - "cohere": { - "name": "Cohere", - "env_var": "COHERE_API_KEY", - "api_url": "https://dashboard.cohere.com/api-keys", - "keywords": ["cohere", "command"], - }, - "deepseek": { - "name": "DeepSeek", - "env_var": "DEEPSEEK_API_KEY", - "api_url": "https://platform.deepseek.com/api_keys", - "keywords": ["deepseek"], - }, -} - - -def get_config_path() -> Path: - """Get path to user config file.""" - config_dir = Path.home() / ".yt-study" - config_dir.mkdir(exist_ok=True) - return config_dir / "config.env" - - -def load_config() -> dict[str, str]: - """Load existing configuration.""" - config_path = get_config_path() - loaded_config = {} - - if config_path.exists(): - try: - with config_path.open(encoding="utf-8") as f: - for line in f: - line = line.strip() - if line and not line.startswith("#") and "=" in line: - key, value = line.split("=", 1) - loaded_config[key.strip()] = value.strip() - except Exception: - # If corrupted, return empty - pass - - return loaded_config - - -def save_config(new_config: dict[str, str]) -> None: - """ - Save configuration to file, preserving existing keys. - - Args: - new_config: Dictionary of new configuration values to merge/update. 
- """ - config_path = get_config_path() - current_config = load_config() - - # Merge new config into current config - current_config.update(new_config) - - with config_path.open("w", encoding="utf-8") as f: - f.write("# yt-study Configuration\n") - f.write("# Generated by yt-study setup wizard\n\n") - - # Sort keys for consistent output, prioritize critical ones - priority_keys = ["DEFAULT_MODEL", "OUTPUT_DIR", "MAX_CONCURRENT_VIDEOS"] - for key in priority_keys: - if key in current_config: - f.write(f"{key}={current_config[key]}\n") - - # Write remaining keys - for key, value in sorted(current_config.items()): - if key not in priority_keys: - f.write(f"{key}={value}\n") - - console.print( - f"\n[green]✓[/green] Configuration saved to: [cyan]{config_path}[/cyan]" - ) - - -def get_available_models() -> dict[str, list[str]]: - """Fetch available models from LiteLLM.""" - try: - # Lazy import to avoid slow startup if not running setup - from litellm import model_list - - # Group models by provider - provider_models: dict[str, list[str]] = {} - - for model in model_list: - # Determine provider from model name - provider = None - model_lower = model.lower() - - for prov_key, prov_config in PROVIDER_CONFIG.items(): - if any(keyword in model_lower for keyword in prov_config["keywords"]): - provider = prov_key - break - - if provider: - if provider not in provider_models: - provider_models[provider] = [] - provider_models[provider].append(model) - - # Sort models within each provider - for provider in provider_models: - provider_models[provider] = sorted(set(provider_models[provider])) - - return provider_models - - except Exception as e: - console.print(f"[yellow]⚠ Could not fetch models from LiteLLM: {e}[/yellow]") - console.print("[yellow]Using fallback model list...[/yellow]") - - # Fallback to curated list - return { - "gemini": [ - "gemini/gemini-2.0-flash-exp", - "gemini/gemini-2.0-flash", - "gemini/gemini-1.5-pro", - "gemini/gemini-1.5-flash", - ], - "openai": [ - 
"gpt-4o", - "gpt-4o-mini", - "gpt-4-turbo", - "o1", - "o1-mini", - ], - "anthropic": [ - "anthropic/claude-3-5-sonnet-20241022", - "anthropic/claude-3-5-haiku-20241022", - "anthropic/claude-3-opus-20240229", - ], - "groq": [ - "groq/llama-3.3-70b-versatile", - "groq/llama-3.1-8b-instant", - "groq/mixtral-8x7b-32768", - ], - "xai": [ - "xai/grok-2-latest", - "xai/grok-2-vision-latest", - ], - } - - -def select_provider(available_models: dict[str, list[str]]) -> str: - """Interactive provider selection.""" - console.print("\n[bold cyan]Select LLM Provider:[/bold cyan]\n") - - # Create table of providers - table = Table(show_header=True, header_style="bold magenta") - table.add_column("#", style="dim", width=4) - table.add_column("Provider", style="cyan") - table.add_column("Models Available", style="dim") - - providers_list = [] - for prov_key, _prov_config in PROVIDER_CONFIG.items(): - if prov_key in available_models: - providers_list.append(prov_key) - - for i, provider_key in enumerate(providers_list, 1): - config_data = PROVIDER_CONFIG[provider_key] - model_count = len(available_models.get(provider_key, [])) - table.add_row(str(i), config_data["name"], f"{model_count} models") - - console.print(table) - console.print(f"\n[dim]Total providers: {len(providers_list)}[/dim]") - - while True: - choice = Prompt.ask( - "\nSelect provider", - choices=[str(i) for i in range(1, len(providers_list) + 1)], - ) - return providers_list[int(choice) - 1] - - -def select_model(provider_key: str, available_models: dict[str, list[str]]) -> str: - """Interactive model selection.""" - provider_config = PROVIDER_CONFIG[provider_key] - models = available_models.get(provider_key, []) - - if not models: - console.print(f"[yellow]No models found for {provider_config['name']}[/yellow]") - return f"{provider_key}/default" - - console.print(f"\n[bold cyan]Select {provider_config['name']} Model:[/bold cyan]\n") - console.print(f"[dim]Showing {len(models)} available models[/dim]\n") - - # Show 
models in pages of 20 - page_size = 20 - current_page = 0 - - while True: - start_idx = current_page * page_size - end_idx = min(start_idx + page_size, len(models)) - page_models = models[start_idx:end_idx] - - # Create table of models - table = Table(show_header=True, header_style="bold magenta") - table.add_column("#", style="dim", width=4) - table.add_column("Model", style="green") - - for i, model in enumerate(page_models, start_idx + 1): - # Highlight recommended models - model_display = model - if "flash" in model.lower() or "mini" in model.lower(): - model_display = f"{model} [dim](fast)[/dim]" - elif ( - "pro" in model.lower() - or "turbo" in model.lower() - or "sonnet" in model.lower() - ): - model_display = f"{model} [dim](powerful)[/dim]" - - table.add_row(str(i), model_display) - - console.print(table) - - # Navigation info - total_pages = (len(models) + page_size - 1) // page_size - console.print( - f"\n[dim]Page {current_page + 1}/{total_pages} | " - f"Showing {start_idx + 1}-{end_idx} of {len(models)} " - f"models[/dim]" - ) - - if total_pages > 1: - console.print( - "[dim]Type 'n' for next page, 'p' for previous page, " - "or model number to select[/dim]" - ) - - choice = Prompt.ask("\nSelect model (or n/p for navigation)") - - if choice.lower() == "n" and current_page < total_pages - 1: - current_page += 1 - console.clear() - console.print( - f"\n[bold cyan]Select {provider_config['name']} Model:[/bold cyan]\n" - ) - continue - elif choice.lower() == "p" and current_page > 0: - current_page -= 1 - console.clear() - console.print( - f"\n[bold cyan]Select {provider_config['name']} Model:[/bold cyan]\n" - ) - continue - elif choice.isdigit() and 1 <= int(choice) <= len(models): - selected = models[int(choice) - 1] - - # Ensure Gemini models have correct prefix for Google AI Studio - if ( - provider_key == "gemini" - and not selected.startswith("gemini/") - and not selected.startswith("vertex_ai/") - ): - return f"gemini/{selected}" - - return selected 
- - -def get_api_key(provider_key: str, existing_key: str | None = None) -> str: - """Prompt for API key.""" - provider = PROVIDER_CONFIG[provider_key] - - console.print(f"\n[bold yellow]API Key Required:[/bold yellow] {provider['name']}") - console.print( - f"[dim]Get your API key from:[/dim] " - f"[link={provider['api_url']}]{provider['api_url']}[/link]\n" - ) - - if existing_key: - masked = ( - f"{existing_key[:8]}...{existing_key[-4:]}" - if len(existing_key) > 12 - else "***" - ) - use_existing = Confirm.ask(f"Use existing key ({masked})?", default=True) - if use_existing: - return existing_key - - while True: - api_key = Prompt.ask("Enter your API key", password=True) - if api_key and len(api_key) > 10: # Basic validation - return api_key - console.print("[red]Invalid API key. Please try again.[/red]") - - -def run_setup_wizard(force: bool = False) -> dict[str, str]: - """Run interactive setup wizard.""" - console.print( - Panel( - "[bold cyan]🎓 yt-study Setup Wizard[/bold cyan]\n\n" - "Configure your LLM provider and API keys\n" - "[dim]Powered by LiteLLM - 400+ models supported[/dim]", - border_style="cyan", - expand=False, - ) - ) - - # Load existing config - current_config = load_config() - - if current_config and not force: - console.print("\n[yellow]Existing configuration found.[/yellow]") - reconfigure = Confirm.ask("Do you want to reconfigure?", default=False) - if not reconfigure: - console.print("[green]Using existing configuration.[/green]") - return current_config - - # Fetch available models from LiteLLM - console.print("\n[cyan]Fetching available models from LiteLLM...[/cyan]") - available_models = get_available_models() - console.print( - f"[green]✓ Found {sum(len(m) for m in available_models.values())} " - f"models across {len(available_models)} providers[/green]" - ) - - # Select provider - provider_key = select_provider(available_models) - - # Select model - model = select_model(provider_key, available_models) - - # Get API key - 
provider_info = PROVIDER_CONFIG[provider_key] - existing_key = current_config.get(provider_info["env_var"]) - api_key = get_api_key(provider_key, existing_key) - - # Optional: Configure output directory - console.print("\n[bold cyan]Output Directory:[/bold cyan]") - default_output = str(Path.cwd() / "output") - # If output dir already in config, use it as default - if "OUTPUT_DIR" in current_config: - default_output = current_config["OUTPUT_DIR"] - - output_dir = Prompt.ask("Where should notes be saved?", default=default_output) - - # Optional: Configure concurrency - console.print("\n[bold cyan]Concurrency:[/bold cyan]") - default_concurrency = current_config.get("MAX_CONCURRENT_VIDEOS", "5") - concurrency = Prompt.ask( - "Max concurrent videos to process?", default=default_concurrency - ) - - # Build config updates - new_config = { - "DEFAULT_MODEL": model, - provider_info["env_var"]: api_key, - "OUTPUT_DIR": output_dir, - "MAX_CONCURRENT_VIDEOS": concurrency, - } - - # Save configuration (merging with existing) - save_config(new_config) - - console.print("\n[bold green]✓ Setup complete![/bold green]") - console.print( - Panel( - f"[dim]Selected model:[/dim] [cyan]{model}[/cyan]\n" - f"[dim]Configuration saved to:[/dim] [cyan]{get_config_path()}[/cyan]\n\n" - "[bold]Next Steps:[/bold]\n" - 'Run: [green]yt-study process "URL"[/green]', - title="🎉 Ready to go", - border_style="green", - ) - ) - - # Return merged config - current_config.update(new_config) - return current_config +# Re-export everything from the new location +from yt_study.core.setup_wizard import * # noqa: F401, F403 diff --git a/src/yt_study/youtube/__init__.py b/src/yt_study/youtube/__init__.py index c09344d..9d869cb 100644 --- a/src/yt_study/youtube/__init__.py +++ b/src/yt_study/youtube/__init__.py @@ -1 +1,3 @@ -"""YouTube module for handling video URLs, playlists, and transcripts.""" +"""Backward compatibility - youtube package moved to core.youtube.""" + +from yt_study.core.youtube import * # 
noqa: F401, F403 diff --git a/src/yt_study/youtube/metadata.py b/src/yt_study/youtube/metadata.py index 14f2bca..e5f1206 100644 --- a/src/yt_study/youtube/metadata.py +++ b/src/yt_study/youtube/metadata.py @@ -1,163 +1,3 @@ -"""Video metadata extraction using pytubefix.""" +"""Backward compatibility - youtube.metadata moved to core.youtube.metadata.""" -import logging -from dataclasses import dataclass -from typing import Any - -from pytubefix import Playlist, YouTube -from rich.console import Console - - -console = Console() -logger = logging.getLogger(__name__) - - -@dataclass -class VideoChapter: - """ - A video chapter with title and time range. - - Attributes: - title: Chapter title. - start_seconds: Start time in seconds. - end_seconds: End time in seconds (None for the last chapter). - """ - - title: str - start_seconds: int - end_seconds: int | None = None - - -def get_video_chapters(video_id: str) -> list[VideoChapter]: - """ - Get chapters from a YouTube video. - - Note: This function performs blocking network I/O. - - Args: - video_id: YouTube video ID. - - Returns: - List of VideoChapter objects, empty if no chapters found. 
- """ - try: - url = f"https://www.youtube.com/watch?v={video_id}" - yt = YouTube(url) - - # Access chapters if available - # pytubefix properties trigger network calls - if hasattr(yt, "chapters") and yt.chapters: - chapters: list[VideoChapter] = [] - chapter_data = yt.chapters - - for i, chapter in enumerate(chapter_data): - # Handle pytubefix chapter object structure (dict or object) - start_time = _get_attr_or_item(chapter, "start_seconds", 0) - title = _get_attr_or_item(chapter, "title", f"Chapter {i + 1}") - - # Calculate end time (start of next chapter or None for last) - end_time = None - if i < len(chapter_data) - 1: - next_chapter = chapter_data[i + 1] - end_time = _get_attr_or_item(next_chapter, "start_seconds", None) - - chapters.append( - VideoChapter( - title=str(title), - start_seconds=int(start_time), - end_seconds=int(end_time) if end_time is not None else None, - ) - ) - - return chapters - - except Exception as e: - logger.debug(f"Could not fetch chapters for {video_id}: {e}") - - return [] - - -def get_video_title(video_id: str) -> str: - """ - Get the title of a YouTube video. - - Note: This function performs blocking network I/O. - - Args: - video_id: YouTube video ID. - - Returns: - Video title, or video ID if title cannot be fetched. - """ - try: - url = f"https://www.youtube.com/watch?v={video_id}" - yt = YouTube(url) - title = yt.title - - if title: - return str(title) - - except Exception as e: - logger.warning(f"Could not fetch title for {video_id}: {e}") - - # Fallback to video ID - return video_id - - -def get_video_duration(video_id: str) -> int: - """ - Get video duration in seconds. - - Note: This function performs blocking network I/O. - - Args: - video_id: YouTube video ID. - - Returns: - Duration in seconds, 0 if cannot be fetched. 
- """ - try: - url = f"https://www.youtube.com/watch?v={video_id}" - yt = YouTube(url) - return int(yt.length) - except Exception as e: - logger.warning(f"Could not fetch duration for {video_id}: {e}") - return 0 - - -def get_playlist_info(playlist_id: str) -> tuple[str, int]: - """ - Get playlist title and video count. - - Note: This function performs blocking network I/O. - - Args: - playlist_id: YouTube playlist ID. - - Returns: - Tuple of (title, video_count). - """ - try: - url = f"https://www.youtube.com/playlist?list={playlist_id}" - playlist = Playlist(url) - - # Pytube's title might fail if playlist is private/invalid - title = getattr(playlist, "title", f"playlist_{playlist_id}") - - # Getting length requires fetching the page - # list(playlist.video_urls) is robust but slow for huge playlists - # For metadata, it's acceptable. - count = len(list(playlist.video_urls)) - - return str(title), count - - except Exception as e: - logger.warning(f"Could not fetch playlist info: {e}") - return f"playlist_{playlist_id}", 0 - - -def _get_attr_or_item(obj: Any, key: str, default: Any = None) -> Any: - """Helper to get value from object attribute or dict key.""" - if isinstance(obj, dict): - return obj.get(key, default) - return getattr(obj, key, default) +from yt_study.core.youtube.metadata import * # noqa: F401, F403 diff --git a/src/yt_study/youtube/parser.py b/src/yt_study/youtube/parser.py index 3830189..5a572b0 100644 --- a/src/yt_study/youtube/parser.py +++ b/src/yt_study/youtube/parser.py @@ -1,117 +1,3 @@ -"""YouTube URL parser for video and playlist detection.""" +"""Backward compatibility - youtube.parser moved to core.youtube.parser.""" -import re -from dataclasses import dataclass -from urllib.parse import parse_qs, urlparse - - -@dataclass -class ParsedURL: - """ - Parsed YouTube URL information. - - Attributes: - url_type: Type of the URL ('video' or 'playlist'). - video_id: Extracted video ID (if present). 
- playlist_id: Extracted playlist ID (if present). - """ - - url_type: str # 'video' or 'playlist' - video_id: str | None = None - playlist_id: str | None = None - - -def extract_video_id(url: str) -> str | None: - """ - Extract video ID from various YouTube URL formats. - - Supports: - - Standard: https://www.youtube.com/watch?v=VIDEO_ID - - Short: https://youtu.be/VIDEO_ID - - Embed: https://www.youtube.com/embed/VIDEO_ID - - V-path: https://www.youtube.com/v/VIDEO_ID - - Shorts: https://www.youtube.com/shorts/VIDEO_ID - - Args: - url: The YouTube URL string. - - Returns: - The 11-character video ID if found, else None. - """ - # Common patterns for YouTube Video IDs (11 chars, alphanumeric + _ -) - patterns = [ - r"(?:v=|\/)([0-9A-Za-z_-]{11}).*", - r"youtu\.be\/([0-9A-Za-z_-]{11})", - r"embed\/([0-9A-Za-z_-]{11})", - r"shorts\/([0-9A-Za-z_-]{11})", - ] - - for pattern in patterns: - match = re.search(pattern, url) - if match: - return match.group(1) - - return None - - -def extract_playlist_id(url: str) -> str | None: - """ - Extract playlist ID from YouTube playlist URL. - - Supports: - - https://www.youtube.com/playlist?list=PLAYLIST_ID - - https://www.youtube.com/watch?v=VIDEO_ID&list=PLAYLIST_ID - - Args: - url: The YouTube URL string. - - Returns: - The playlist ID if found, else None. - """ - try: - parsed = urlparse(url) - query_params = parse_qs(parsed.query) - - if "list" in query_params: - return query_params["list"][0] - except Exception: - # Fail gracefully on malformed URLs - pass - - return None - - -def parse_youtube_url(url: str) -> ParsedURL: - """ - Parse a YouTube URL and determine if it's a video or playlist. - - Prioritizes playlist ID if 'list' parameter is present, - but also extracts video ID if available (e.g. watching a playlist). 
- - Args: - url: YouTube URL (video or playlist) - - Returns: - ParsedURL object with url_type and relevant IDs - - Raises: - ValueError: If URL is not a valid YouTube URL (neither video nor playlist) - """ - if not url or not isinstance(url, str): - raise ValueError("URL must be a non-empty string") - - # Check for playlist first - playlist_id = extract_playlist_id(url) - if playlist_id: - # It's a playlist URL - video_id = extract_video_id(url) # Might have both - return ParsedURL( - url_type="playlist", playlist_id=playlist_id, video_id=video_id - ) - - # Check for video - video_id = extract_video_id(url) - if video_id: - return ParsedURL(url_type="video", video_id=video_id) - - raise ValueError(f"Invalid YouTube URL: {url}") +from yt_study.core.youtube.parser import * # noqa: F401, F403 diff --git a/src/yt_study/youtube/playlist.py b/src/yt_study/youtube/playlist.py index 706867a..9a09b31 100644 --- a/src/yt_study/youtube/playlist.py +++ b/src/yt_study/youtube/playlist.py @@ -1,97 +1,3 @@ -"""Playlist video extraction using pytubefix.""" +"""Backward compatibility - youtube.playlist moved to core.youtube.playlist.""" -import asyncio -import logging - -from pytubefix import Playlist -from rich.console import Console - - -console = Console() -logger = logging.getLogger(__name__) - - -class PlaylistError(Exception): - """Exception raised for playlist-related errors.""" - - pass - - -async def extract_playlist_videos(playlist_id: str) -> list[str]: - """ - Extract all video IDs from a YouTube playlist with retry logic. - - This function handles the blocking network calls of pytubefix by offloading - them to a separate thread, ensuring the asyncio event loop remains responsive. - - Args: - playlist_id: YouTube playlist ID. - - Returns: - List of video IDs. - - Raises: - PlaylistError: If playlist cannot be accessed after retries. 
- """ - max_retries = 3 - last_error = None - - for attempt in range(max_retries): - try: - # Wrap blocking pytubefix logic in a thread - video_ids = await asyncio.to_thread(_extract_sync, playlist_id, attempt) - - if not video_ids: - # Should have been raised in _extract_sync if empty, but double check - raise ValueError( - f"No videos found in playlist (Attempt {attempt + 1}/{max_retries})" - ) - - logger.info(f"Found {len(video_ids)} videos in playlist") - return video_ids - - except Exception as e: - last_error = e - logger.warning(f"Playlist extraction attempt {attempt + 1} failed: {e}") - if attempt < max_retries - 1: - wait_time = 2**attempt # Exponential backoff - logger.warning(f"Retrying in {wait_time}s...") - await asyncio.sleep(wait_time) - - logger.error( - f"Failed to extract playlist videos after {max_retries} attempts: {last_error}" - ) - raise PlaylistError(f"Could not access playlist {playlist_id}: {str(last_error)}") - - -def _extract_sync(playlist_id: str, attempt: int) -> list[str]: - """Blocking helper to extract videos using pytubefix.""" - playlist_url = f"https://www.youtube.com/playlist?list={playlist_id}" - playlist = Playlist(playlist_url) - - # Access playlist title to trigger loading - try: - title = playlist.title - if attempt == 0: - logger.info(f"Found playlist: {title}") - except Exception: - # Title fetch might fail but video extraction might still work - logger.warning(f"Could not fetch playlist title on attempt {attempt + 1}") - - video_ids = [] - - # Extract video IDs from URLs (waits for internal generator) - # This loop triggers network requests - for url in playlist.video_urls: - if "v=" in url: - try: - # Robust ID extraction - video_id = url.split("v=")[1].split("&")[0] - video_ids.append(video_id) - except IndexError: - continue - - if not video_ids: - raise ValueError("No videos found in playlist") - - return video_ids +from yt_study.core.youtube.playlist import * # noqa: F401, F403 diff --git 
a/src/yt_study/youtube/transcript.py b/src/yt_study/youtube/transcript.py index 1f8a2ac..392156b 100644 --- a/src/yt_study/youtube/transcript.py +++ b/src/yt_study/youtube/transcript.py @@ -1,297 +1,3 @@ -"""Transcript fetching with multi-language support.""" +"""Backward compatibility - youtube.transcript moved to core.youtube.transcript.""" -import asyncio -import logging -from dataclasses import dataclass -from typing import Any - -from rich.console import Console -from youtube_transcript_api import YouTubeTranscriptApi -from youtube_transcript_api._errors import ( - IpBlocked, - NoTranscriptFound, - RequestBlocked, - TranscriptsDisabled, - VideoUnavailable, -) - -from .metadata import VideoChapter - - -console = Console() -logger = logging.getLogger(__name__) - - -@dataclass -class TranscriptSegment: - """ - A segment of transcript text with timing. - - Attributes: - text: The spoken text. - start: Start time in seconds. - duration: Duration of the segment in seconds. - """ - - text: str - start: float - duration: float - - -@dataclass -class VideoTranscript: - """ - Complete transcript for a video. - - Attributes: - video_id: The YouTube video ID. - segments: List of transcript segments. - language: Language name (e.g., 'English'). - language_code: Language code (e.g., 'en'). - is_generated: Whether the transcript is auto-generated. 
- """ - - video_id: str - segments: list[TranscriptSegment] - language: str - language_code: str - is_generated: bool - - def to_text(self) -> str: - """Convert transcript segments to continuous text.""" - return " ".join(segment.text for segment in self.segments) - - -class TranscriptError(Exception): - """Exception raised for transcript-related errors.""" - - pass - - -class YouTubeIPBlockError(TranscriptError): - """Exception raised when YouTube blocks IP.""" - - pass - - -async def fetch_transcript( - video_id: str, languages: list[str] | None = None -) -> VideoTranscript: - """ - Fetch transcript for a YouTube video with language fallback and retry logic. - - Priority: - 1. Manual transcript in preferred language - 2. Auto-generated transcript in preferred language - 3. Manual transcript in any available language - 4. Auto-generated transcript in any available language - 5. Translated transcript to English - - Args: - video_id: YouTube video ID. - languages: Preferred language codes (e.g., ['en', 'hi']). Defaults to ['en']. - - Returns: - VideoTranscript object. - - Raises: - TranscriptError: If no transcript is available. 
- """ - if languages is None: - languages = ["en"] - - retries = 3 - - for attempt in range(retries): - try: - # Wrap blocking YouTubeTranscriptApi calls in a thread - # This is critical to prevent blocking the asyncio event loop - # during concurrency - raw_transcript, transcript_meta, log_msg = await asyncio.to_thread( - _fetch_sync, video_id, languages - ) - - logger.info(log_msg) - - # Convert to our format - segments = [] - for segment in raw_transcript: - # Handle both dict (standard) and object - # (FetchedTranscriptSnippet) formats - if isinstance(segment, dict): - text = segment.get("text", "") - start = segment.get("start", 0.0) - duration = segment.get("duration", 0.0) - else: - # Fallback for object-based returns - text = getattr(segment, "text", "") - start = getattr(segment, "start", 0.0) - duration = getattr(segment, "duration", 0.0) - - segments.append( - TranscriptSegment( - text=text, start=float(start), duration=float(duration) - ) - ) - - return VideoTranscript( - video_id=video_id, - segments=segments, - language=transcript_meta.language, - language_code=transcript_meta.language_code, - is_generated=transcript_meta.is_generated, - ) - - except (TranscriptsDisabled, VideoUnavailable) as e: - # Fatal errors, do not retry - logger.error(f"Transcript unavailable for {video_id}: {e}") - raise TranscriptError( - f"Transcripts are disabled or video is unavailable: {video_id}" - ) from e - - except (TranscriptError, NoTranscriptFound): - # Already handled or strictly not found, do not retry - raise - - except (IpBlocked, RequestBlocked) as e: - # Specifically handle IP blocking - logger.error(f"YouTube IP Block detected for {video_id}") - raise YouTubeIPBlockError( - "YouTube is blocking requests from your IP. " - "Please try using a VPN, proxies, or wait a while." 
- ) from e - - except Exception as e: - err_str = str(e) - if "blocking requests from your IP" in err_str: - logger.error(f"YouTube IP Block detected for {video_id}: {e}") - raise YouTubeIPBlockError( - "YouTube is blocking requests from your IP. " - "Please try using a VPN, proxies, or wait a while." - ) from e - - if attempt < retries - 1: - wait_time = 2**attempt - logger.warning( - f"Transcript fetch failed ({str(e)}), retrying in {wait_time}s..." - ) - await asyncio.sleep(wait_time) - else: - logger.error(f"Failed to fetch transcript for {video_id}: {e}") - raise TranscriptError(f"Could not fetch transcript: {str(e)}") from e - - # Should be unreachable due to raise in loop - raise TranscriptError(f"Failed to fetch transcript for {video_id}") - - -def _fetch_sync(video_id: str, languages: list[str]) -> tuple[Any, Any, str]: - """Blocking helper to interact with YouTubeTranscriptApi.""" - ytt_api = YouTubeTranscriptApi() - - # List all available transcripts - # This list call can fail with TranscriptsDisabled or VideoUnavailable - transcript_list = ytt_api.list(video_id) - - transcript = None - found_msg = "" - - # Strategy 1: Find manual transcript in preferred language - try: - transcript = transcript_list.find_manually_created_transcript(languages) - found_msg = f"Found manual transcript: {transcript.language}" - except NoTranscriptFound: - pass - - # Strategy 2: Try auto-generated in preferred language - if not transcript: - try: - transcript = transcript_list.find_generated_transcript(languages) - found_msg = f"Using auto-generated transcript: {transcript.language}" - except NoTranscriptFound: - pass - - # Strategy 3: Try any manual transcript - if not transcript: - try: - # Get all language codes available - all_codes = [t.language_code for t in transcript_list] - transcript = transcript_list.find_manually_created_transcript(all_codes) - found_msg = f"Using manual transcript in {transcript.language}" - except NoTranscriptFound: - pass - - # Strategy 4: 
Last resort - try any available transcript and translate if needed - if not transcript: - try: - # list(transcript_list) returns iterable of Transcript objects - available = list(transcript_list) - if not available: - raise NoTranscriptFound(video_id, languages, []) - - first_available = available[0] - - # Try to translate to English if not English already and requested - if "en" in languages and first_available.language_code != "en": - if first_available.is_translatable: - transcript = first_available.translate("en") - found_msg = f"Translated {first_available.language} -> English" - else: - transcript = first_available - found_msg = ( - f"Using {transcript.language} (translation not available)" - ) - else: - transcript = first_available - found_msg = f"Using {transcript.language}" - - except Exception as e: - # If we really can't find anything - if isinstance(e, NoTranscriptFound): - raise - raise TranscriptError(f"No usable transcript found: {e}") from e - - # Fetch the actual transcript data - raw_transcript = transcript.fetch() - return raw_transcript, transcript, found_msg - - -def split_transcript_by_chapters( - transcript: VideoTranscript, chapters: list[VideoChapter] -) -> dict[str, str]: - """ - Split a video transcript by chapters. - - Args: - transcript: VideoTranscript object. - chapters: List of VideoChapter objects. - - Returns: - Dictionary mapping chapter titles to their transcript text. 
- """ - chapter_transcripts = {} - - for chapter in chapters: - # Filter segments for this chapter - chapter_segments = [] - - for segment in transcript.segments: - segment_start = segment.start - - # Check if segment start is within chapter range - if chapter.end_seconds is None: - # Last chapter - include everything after start - if segment_start >= chapter.start_seconds: - chapter_segments.append(segment.text) - else: - # Middle chapters - include if in range - if ( - segment_start >= chapter.start_seconds - and segment_start < chapter.end_seconds - ): - chapter_segments.append(segment.text) - - # Combine segments for this chapter - chapter_text = " ".join(chapter_segments) - chapter_transcripts[chapter.title] = chapter_text - - return chapter_transcripts +from yt_study.core.youtube.transcript import * # noqa: F401, F403 diff --git a/uv.lock b/uv.lock index 00592f7..6109bec 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" [[package]] @@ -2465,7 +2465,7 @@ wheels = [ [[package]] name = "yt-study" -version = "0.1.7" +version = "0.2.1" source = { editable = "." 
} dependencies = [ { name = "aiohttp" }, From a361c39a862d63077834196fae6e66095ed10a0e Mon Sep 17 00:00:00 2001 From: mdabucse Date: Sat, 7 Feb 2026 20:15:33 +0530 Subject: [PATCH 2/3] fix(tests): update tests to use core imports and add fallback shims; ignore generated content - Updated test imports and patch targets to use yt_study.core.* - Updated backward compatibility shims - Added output/ to .gitignore - Fixed lint errors --- .gitignore | 5 +++- src/yt_study/pipeline/__init__.py | 1 + tests/conftest.py | 10 +++---- tests/test_cli.py | 10 +++---- tests/test_config.py | 2 +- tests/test_llm/test_generator.py | 14 +++++----- tests/test_llm/test_providers.py | 8 +++--- tests/test_pipeline/test_orchestrator.py | 34 +++++++++++------------- tests/test_setup_wizard.py | 32 +++++++++++++--------- tests/test_youtube/test_metadata.py | 2 +- tests/test_youtube/test_parser.py | 2 +- tests/test_youtube/test_playlist.py | 2 +- tests/test_youtube/test_transcript.py | 6 ++--- 13 files changed, 68 insertions(+), 60 deletions(-) diff --git a/.gitignore b/.gitignore index 3e0ff89..40b49dd 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,7 @@ nul .coverage drafts/ .omc -.claude \ No newline at end of file +.claude + +# Generated content +src/yt_study/output/ \ No newline at end of file diff --git a/src/yt_study/pipeline/__init__.py b/src/yt_study/pipeline/__init__.py index c3b37b9..b20f6e9 100644 --- a/src/yt_study/pipeline/__init__.py +++ b/src/yt_study/pipeline/__init__.py @@ -2,4 +2,5 @@ from yt_study.core.orchestrator import PipelineOrchestrator, sanitize_filename + __all__ = ["PipelineOrchestrator", "sanitize_filename"] diff --git a/tests/conftest.py b/tests/conftest.py index 45f70dc..0db75b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,7 +33,7 @@ def mock_config(monkeypatch): # Reload config to pick up env vars if necessary, # or just rely on Config loading from env. 
- from yt_study.config import config + from yt_study.core.config import config config.gemini_api_key = "dummy_gemini_key" config.openai_api_key = "dummy_openai_key" @@ -52,17 +52,17 @@ def mock_llm_provider(): @pytest.fixture def mock_transcript_api(mocker): """Mock YouTubeTranscriptApi class.""" - return mocker.patch("yt_study.youtube.transcript.YouTubeTranscriptApi") + return mocker.patch("yt_study.core.youtube.transcript.YouTubeTranscriptApi") @pytest.fixture def mock_pytube(mocker): """Mock pytubefix YouTube and Playlist classes.""" # We patch the classes where they are imported in metadata.py - mock_yt = mocker.patch("yt_study.youtube.metadata.YouTube") - mock_pl = mocker.patch("yt_study.youtube.metadata.Playlist") + mock_yt = mocker.patch("yt_study.core.youtube.metadata.YouTube") + mock_pl = mocker.patch("yt_study.core.youtube.metadata.Playlist") # Also patch in playlist.py if used there - mocker.patch("yt_study.youtube.playlist.Playlist", new=mock_pl) + mocker.patch("yt_study.core.youtube.playlist.Playlist", new=mock_pl) return mock_yt, mock_pl diff --git a/tests/test_cli.py b/tests/test_cli.py index 16ec627..c9c04c6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -14,7 +14,7 @@ @pytest.fixture def mock_orchestrator(): # noqa: ARG001 # Patch where PipelineOrchestrator is defined - with patch("yt_study.pipeline.orchestrator.PipelineOrchestrator") as mock: + with patch("yt_study.core.orchestrator.PipelineOrchestrator") as mock: instance = mock.return_value instance.run = AsyncMock() yield mock @@ -114,7 +114,7 @@ def test_process_missing_config(): """Test that missing config triggers setup check/error.""" with ( patch("yt_study.cli.check_config_exists", return_value=False), - patch("yt_study.setup_wizard.run_setup_wizard") as mock_setup, + patch("yt_study.core.setup_wizard.run_setup_wizard") as mock_setup, ): runner.invoke(app, ["process", "url"]) mock_setup.assert_called_once() @@ -146,7 +146,7 @@ def 
test_process_general_exception(mock_config_exists, mock_orchestrator): # no def test_setup_command(): """Test setup command triggers wizard.""" - with patch("yt_study.setup_wizard.run_setup_wizard") as mock_wizard: + with patch("yt_study.core.setup_wizard.run_setup_wizard") as mock_wizard: result = runner.invoke(app, ["setup"]) assert result.exit_code == 0 mock_wizard.assert_called_once() @@ -155,7 +155,7 @@ def test_setup_command(): def test_setup_import_error(): """Test setup command handling missing wizard module.""" # Simulate ImportError when importing setup_wizard - with patch.dict("sys.modules", {"yt_study.setup_wizard": None}): + with patch.dict("sys.modules", {"yt_study.core.setup_wizard": None}): # This approach is tricky because we are inside the test process. # Better to patch the specific import or function call if lazy. pass @@ -165,7 +165,7 @@ def test_ensure_setup_import_error(): """Test ensure_setup handles missing wizard.""" with ( patch("yt_study.cli.check_config_exists", return_value=False), - patch.dict("sys.modules", {"yt_study.setup_wizard": None}), + patch.dict("sys.modules", {"yt_study.core.setup_wizard": None}), ): # This won't work easily as the module is likely already imported. 
pass diff --git a/tests/test_config.py b/tests/test_config.py index c771c11..8e37a38 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,7 +3,7 @@ import os from unittest.mock import patch -from yt_study.config import Config +from yt_study.core.config import Config class TestConfig: diff --git a/tests/test_llm/test_generator.py b/tests/test_llm/test_generator.py index f524a03..772263a 100644 --- a/tests/test_llm/test_generator.py +++ b/tests/test_llm/test_generator.py @@ -4,8 +4,8 @@ import pytest -from yt_study.config import config -from yt_study.llm.generator import StudyMaterialGenerator +from yt_study.core.config import config +from yt_study.core.llm.generator import StudyMaterialGenerator class TestStudyMaterialGenerator: @@ -18,14 +18,14 @@ def generator(self, mock_llm_provider): def test_count_tokens_fallback(self, generator): """Test token counting fallback when library fails.""" with patch( - "yt_study.llm.generator.token_counter", side_effect=Exception("Error") + "yt_study.core.llm.generator.token_counter", side_effect=Exception("Error") ): count = generator._count_tokens("1234") assert count == 1 # 4 chars // 4 = 1 def test_chunk_transcript_small(self, generator): """Test that small transcripts are not chunked.""" - with patch("yt_study.llm.generator.token_counter", return_value=100): + with patch("yt_study.core.llm.generator.token_counter", return_value=100): chunks = generator._chunk_transcript("Small text") assert len(chunks) == 1 assert chunks[0] == "Small text" @@ -36,7 +36,7 @@ def test_chunk_transcript_sentences(self, generator): config.chunk_size = 5 # Allow room for a sentence + delimiter try: - with patch("yt_study.llm.generator.token_counter") as mock_tc: + with patch("yt_study.core.llm.generator.token_counter") as mock_tc: # 1 token per word, with the delimiter ". 
" adding extra tokens def count_tokens(_model, text): # noqa: ARG001 return len(text.split()) @@ -60,7 +60,7 @@ def test_chunk_transcript_newlines(self, generator): config.chunk_size = 2 try: - with patch("yt_study.llm.generator.token_counter") as mock_tc: + with patch("yt_study.core.llm.generator.token_counter") as mock_tc: mock_tc.side_effect = lambda _model, text: len(text.split()) # noqa: ARG005 # No periods, just newlines @@ -78,7 +78,7 @@ def test_chunk_transcript_hard_split(self, generator): config.chunk_size = 1 # Tiny try: - with patch("yt_study.llm.generator.token_counter") as mock_tc: + with patch("yt_study.core.llm.generator.token_counter") as mock_tc: # Mock token counter to say everything is too big mock_tc.side_effect = lambda _model, text: len(text) # noqa: ARG005 diff --git a/tests/test_llm/test_providers.py b/tests/test_llm/test_providers.py index b07e6a9..b77f5e1 100644 --- a/tests/test_llm/test_providers.py +++ b/tests/test_llm/test_providers.py @@ -4,7 +4,7 @@ import pytest -from yt_study.llm.providers import LLMGenerationError, LLMProvider, get_provider +from yt_study.core.llm.providers import LLMGenerationError, LLMProvider, get_provider class TestLLMProvider: @@ -20,7 +20,7 @@ def test_init_validation(self, mock_config): # noqa: ARG002 @pytest.mark.asyncio async def test_generate_success(self): """Test successful generation.""" - with patch("yt_study.llm.providers.acompletion") as mock_acompletion: + with patch("yt_study.core.llm.providers.acompletion") as mock_acompletion: # Setup mock response mock_response = MagicMock() mock_response.choices[0].message.content = "Generated content" @@ -43,7 +43,7 @@ async def test_generate_success(self): @pytest.mark.asyncio async def test_generate_cleanup_markdown(self): """Test cleaning of markdown code blocks from response.""" - with patch("yt_study.llm.providers.acompletion") as mock_acompletion: + with patch("yt_study.core.llm.providers.acompletion") as mock_acompletion: mock_response = MagicMock() # 
LLM returns content wrapped in ```markdown ... ``` mock_response.choices[ @@ -59,7 +59,7 @@ async def test_generate_cleanup_markdown(self): @pytest.mark.asyncio async def test_generate_failure(self): """Test generation failure raises custom exception.""" - with patch("yt_study.llm.providers.acompletion") as mock_acompletion: + with patch("yt_study.core.llm.providers.acompletion") as mock_acompletion: mock_acompletion.side_effect = Exception("API Error") provider = LLMProvider("gpt-4o") diff --git a/tests/test_pipeline/test_orchestrator.py b/tests/test_pipeline/test_orchestrator.py index ac71d01..7b3237b 100644 --- a/tests/test_pipeline/test_orchestrator.py +++ b/tests/test_pipeline/test_orchestrator.py @@ -4,7 +4,7 @@ import pytest -from yt_study.pipeline.orchestrator import PipelineOrchestrator, sanitize_filename +from yt_study.core.orchestrator import PipelineOrchestrator, sanitize_filename def test_sanitize_filename(): @@ -21,7 +21,7 @@ class TestPipelineOrchestrator: @pytest.fixture def orchestrator(self, temp_output_dir, mock_llm_provider): with patch( - "yt_study.pipeline.orchestrator.get_provider", + "yt_study.core.orchestrator.get_provider", return_value=mock_llm_provider, ): orch = PipelineOrchestrator(model="mock-model", output_dir=temp_output_dir) @@ -40,7 +40,8 @@ def test_validate_provider_missing_key(self, orchestrator, monkeypatch): """Test validation fails if key is missing.""" # Mock config to return key name but env var is empty with patch( - "yt_study.config.config.get_api_key_name_for_model", return_value="TEST_KEY" + "yt_study.core.config.config.get_api_key_name_for_model", + return_value="TEST_KEY", ): monkeypatch.delenv("TEST_KEY", raising=False) assert orchestrator.validate_provider() is False @@ -48,7 +49,8 @@ def test_validate_provider_missing_key(self, orchestrator, monkeypatch): def test_validate_provider_success(self, orchestrator, monkeypatch): """Test validation succeeds if key exists.""" with patch( - 
"yt_study.config.config.get_api_key_name_for_model", return_value="TEST_KEY" + "yt_study.core.config.config.get_api_key_name_for_model", + return_value="TEST_KEY", ): monkeypatch.setenv("TEST_KEY", "123") assert orchestrator.validate_provider() is True @@ -59,15 +61,13 @@ async def test_process_video_single(self, orchestrator): # Mock dependencies with ( patch( - "yt_study.pipeline.orchestrator.get_video_title", + "yt_study.core.orchestrator.get_video_title", return_value="Test Video", ), + patch("yt_study.core.orchestrator.get_video_duration", return_value=100), + patch("yt_study.core.orchestrator.get_video_chapters", return_value=[]), patch( - "yt_study.pipeline.orchestrator.get_video_duration", return_value=100 - ), - patch("yt_study.pipeline.orchestrator.get_video_chapters", return_value=[]), - patch( - "yt_study.pipeline.orchestrator.fetch_transcript", + "yt_study.core.orchestrator.fetch_transcript", new_callable=AsyncMock, ) as mock_fetch, ): @@ -92,22 +92,20 @@ async def test_process_video_with_chapters(self, orchestrator): # Duration > 3600 (1h) + Chapters present with ( patch( - "yt_study.pipeline.orchestrator.get_video_title", + "yt_study.core.orchestrator.get_video_title", return_value="Long Video", ), + patch("yt_study.core.orchestrator.get_video_duration", return_value=4000), patch( - "yt_study.pipeline.orchestrator.get_video_duration", return_value=4000 - ), - patch( - "yt_study.pipeline.orchestrator.get_video_chapters", + "yt_study.core.orchestrator.get_video_chapters", return_value=["chap1"], ), patch( - "yt_study.pipeline.orchestrator.fetch_transcript", + "yt_study.core.orchestrator.fetch_transcript", new_callable=AsyncMock, ) as mock_fetch, patch( - "yt_study.pipeline.orchestrator.split_transcript_by_chapters", + "yt_study.core.orchestrator.split_transcript_by_chapters", return_value={"Ch1": "text"}, ), ): @@ -133,7 +131,7 @@ async def test_process_video_with_chapters(self, orchestrator): async def test_run_video_flow(self, orchestrator): """Test 
run() method flow for a video URL.""" with ( - patch("yt_study.pipeline.orchestrator.parse_youtube_url") as mock_parse, + patch("yt_study.core.orchestrator.parse_youtube_url") as mock_parse, patch.object( orchestrator, "_process_with_dashboard", new_callable=AsyncMock ) as mock_dash, diff --git a/tests/test_setup_wizard.py b/tests/test_setup_wizard.py index 8718549..8f8b0b1 100644 --- a/tests/test_setup_wizard.py +++ b/tests/test_setup_wizard.py @@ -2,7 +2,7 @@ from unittest.mock import mock_open, patch -from yt_study.setup_wizard import ( +from yt_study.core.setup_wizard import ( get_api_key, get_available_models, load_config, @@ -58,10 +58,11 @@ def test_save_config(self): mock_path = Path("dummy_path") with ( patch( - "yt_study.setup_wizard.load_config", return_value={"OLD_KEY": "old_val"} + "yt_study.core.setup_wizard.load_config", + return_value={"OLD_KEY": "old_val"}, ), patch("pathlib.Path.open", mock_open()) as mock_file, - patch("yt_study.setup_wizard.get_config_path", return_value=mock_path), + patch("yt_study.core.setup_wizard.get_config_path", return_value=mock_path), ): new_config = {"NEW_KEY": "new_val", "DEFAULT_MODEL": "new_model"} save_config(new_config) @@ -151,7 +152,7 @@ def test_select_provider(self): } with ( - patch("yt_study.setup_wizard.PROVIDER_CONFIG", test_config), + patch("yt_study.core.setup_wizard.PROVIDER_CONFIG", test_config), patch("rich.prompt.Prompt.ask", return_value="2"), ): result = select_provider({"p1": [], "p2": []}) @@ -167,7 +168,7 @@ def test_select_model_pagination(self): inputs = ["n", "p", "1"] with ( - patch("yt_study.setup_wizard.PROVIDER_CONFIG", {"p1": {"name": "P1"}}), + patch("yt_study.core.setup_wizard.PROVIDER_CONFIG", {"p1": {"name": "P1"}}), patch("rich.prompt.Prompt.ask", side_effect=inputs), ): selected = select_model("p1", models) @@ -179,7 +180,8 @@ def test_select_model_gemini_prefix(self): with ( patch( - "yt_study.setup_wizard.PROVIDER_CONFIG", {"gemini": {"name": "Google"}} + 
"yt_study.core.setup_wizard.PROVIDER_CONFIG", + {"gemini": {"name": "Google"}}, ), patch("rich.prompt.Prompt.ask", return_value="1"), ): @@ -221,18 +223,19 @@ def test_run_setup_wizard_full_flow(self): """Test full setup flow.""" # Mocks with ( - patch("yt_study.setup_wizard.load_config", return_value={}), + patch("yt_study.core.setup_wizard.load_config", return_value={}), patch( - "yt_study.setup_wizard.get_available_models", + "yt_study.core.setup_wizard.get_available_models", return_value={"gemini": ["gemini-pro"]}, ), - patch("yt_study.setup_wizard.select_provider", return_value="gemini"), + patch("yt_study.core.setup_wizard.select_provider", return_value="gemini"), patch( - "yt_study.setup_wizard.select_model", return_value="gemini/gemini-pro" + "yt_study.core.setup_wizard.select_model", + return_value="gemini/gemini-pro", ), - patch("yt_study.setup_wizard.get_api_key", return_value="new-key"), + patch("yt_study.core.setup_wizard.get_api_key", return_value="new-key"), patch("rich.prompt.Prompt.ask", side_effect=["/custom/out", "10"]), - patch("yt_study.setup_wizard.save_config") as mock_save, + patch("yt_study.core.setup_wizard.save_config") as mock_save, ): config = run_setup_wizard(force=True) @@ -246,7 +249,10 @@ def test_run_setup_wizard_full_flow(self): def test_run_setup_wizard_skip_existing(self): """Test skipping setup if config exists.""" with ( - patch("yt_study.setup_wizard.load_config", return_value={"exists": "true"}), + patch( + "yt_study.core.setup_wizard.load_config", + return_value={"exists": "true"}, + ), patch("rich.prompt.Confirm.ask", return_value=False), ): # Do not reconfigure config = run_setup_wizard(force=False) diff --git a/tests/test_youtube/test_metadata.py b/tests/test_youtube/test_metadata.py index 2e2749b..8866714 100644 --- a/tests/test_youtube/test_metadata.py +++ b/tests/test_youtube/test_metadata.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock, PropertyMock -from yt_study.youtube.metadata import ( +from 
yt_study.core.youtube.metadata import ( get_playlist_info, get_video_chapters, get_video_duration, diff --git a/tests/test_youtube/test_parser.py b/tests/test_youtube/test_parser.py index 33d9fef..be3362c 100644 --- a/tests/test_youtube/test_parser.py +++ b/tests/test_youtube/test_parser.py @@ -2,7 +2,7 @@ import pytest -from yt_study.youtube.parser import ( +from yt_study.core.youtube.parser import ( extract_playlist_id, extract_video_id, parse_youtube_url, diff --git a/tests/test_youtube/test_playlist.py b/tests/test_youtube/test_playlist.py index 20e3b2a..9ad1993 100644 --- a/tests/test_youtube/test_playlist.py +++ b/tests/test_youtube/test_playlist.py @@ -4,7 +4,7 @@ import pytest -from yt_study.youtube.playlist import PlaylistError, extract_playlist_videos +from yt_study.core.youtube.playlist import PlaylistError, extract_playlist_videos class TestPlaylistExtraction: diff --git a/tests/test_youtube/test_transcript.py b/tests/test_youtube/test_transcript.py index 5379d4b..f1dc78c 100644 --- a/tests/test_youtube/test_transcript.py +++ b/tests/test_youtube/test_transcript.py @@ -5,8 +5,8 @@ import pytest from youtube_transcript_api._errors import NoTranscriptFound, VideoUnavailable -from yt_study.youtube.metadata import VideoChapter -from yt_study.youtube.transcript import ( +from yt_study.core.youtube.metadata import VideoChapter +from yt_study.core.youtube.transcript import ( TranscriptError, VideoTranscript, fetch_transcript, @@ -21,7 +21,7 @@ class TestFetchTranscript: def mock_transcript_api_instance(self, mocker): """Mock the YouTubeTranscriptApi class and its instance.""" # Patch the class - mock_cls = mocker.patch("yt_study.youtube.transcript.YouTubeTranscriptApi") + mock_cls = mocker.patch("yt_study.core.youtube.transcript.YouTubeTranscriptApi") # The instance returned by constructor mock_instance = mock_cls.return_value return mock_instance From 314a84d6636e5eceea9d9bc274b47bbcaa06e33d Mon Sep 17 00:00:00 2001 From: mdabucse Date: Wed, 18 Feb 2026 
18:52:46 +0530 Subject: [PATCH 3/3] refactor: decouple core from UI, fix bugs, remove duplication - Fix path traversal vulnerability in sanitize_filename (block '.' and '..') - Remove all Rich imports from core/ (generator, providers, transcript, metadata, playlist) - Replace Rich Progress/TaskID in generator with on_status callback - Create core/events.py with EventType, PipelineEvent, PipelineResult - Create ui/presenter.py (RichPipelinePresenter wraps CorePipeline) - Move setup_wizard.py from core/ to ui/wizard.py - Update cli.py to use RichPipelinePresenter and ui.wizard - Wrap blocking I/O (mkdir, write_text) in asyncio.to_thread - Delete legacy shim files: config.py, setup_wizard.py, llm/, pipeline/, prompts/, youtube/ - Delete core/orchestrator.py (replaced by ui/presenter.py) - Add tests/test_core_pipeline.py with 10 tests - Update test_cli.py, test_orchestrator.py, test_setup_wizard.py imports All 112 tests pass. --- src/yt_study/cli.py | 381 +++++-------- src/yt_study/config.py | 4 - src/yt_study/core/__init__.py | 11 +- src/yt_study/core/events.py | 49 ++ src/yt_study/core/llm/generator.py | 104 +--- src/yt_study/core/llm/providers.py | 2 - src/yt_study/core/orchestrator.py | 534 ------------------ src/yt_study/core/pipeline.py | 66 +-- src/yt_study/core/youtube/metadata.py | 2 - src/yt_study/core/youtube/playlist.py | 2 - src/yt_study/core/youtube/transcript.py | 2 - src/yt_study/llm/__init__.py | 3 - src/yt_study/llm/generator.py | 3 - src/yt_study/llm/providers.py | 3 - src/yt_study/pipeline/__init__.py | 6 - src/yt_study/pipeline/orchestrator.py | 4 - src/yt_study/prompts/__init__.py | 1 - src/yt_study/prompts/chapter_notes.py | 55 -- src/yt_study/prompts/study_notes.py | 106 ---- src/yt_study/setup_wizard.py | 4 - src/yt_study/ui/presenter.py | 287 ++++++++++ .../{core/setup_wizard.py => ui/wizard.py} | 0 src/yt_study/youtube/__init__.py | 3 - src/yt_study/youtube/metadata.py | 3 - src/yt_study/youtube/parser.py | 3 - 
src/yt_study/youtube/playlist.py | 3 - src/yt_study/youtube/transcript.py | 3 - tests/test_cli.py | 74 +-- tests/test_core_pipeline.py | 214 +++++++ tests/test_pipeline/test_orchestrator.py | 130 +---- tests/test_setup_wizard.py | 26 +- 31 files changed, 820 insertions(+), 1268 deletions(-) delete mode 100644 src/yt_study/config.py create mode 100644 src/yt_study/core/events.py delete mode 100644 src/yt_study/core/orchestrator.py delete mode 100644 src/yt_study/llm/__init__.py delete mode 100644 src/yt_study/llm/generator.py delete mode 100644 src/yt_study/llm/providers.py delete mode 100644 src/yt_study/pipeline/__init__.py delete mode 100644 src/yt_study/pipeline/orchestrator.py delete mode 100644 src/yt_study/prompts/__init__.py delete mode 100644 src/yt_study/prompts/chapter_notes.py delete mode 100644 src/yt_study/prompts/study_notes.py delete mode 100644 src/yt_study/setup_wizard.py create mode 100644 src/yt_study/ui/presenter.py rename src/yt_study/{core/setup_wizard.py => ui/wizard.py} (100%) delete mode 100644 src/yt_study/youtube/__init__.py delete mode 100644 src/yt_study/youtube/metadata.py delete mode 100644 src/yt_study/youtube/parser.py delete mode 100644 src/yt_study/youtube/playlist.py delete mode 100644 src/yt_study/youtube/transcript.py create mode 100644 tests/test_core_pipeline.py diff --git a/src/yt_study/cli.py b/src/yt_study/cli.py index b612fc5..7ab20bc 100644 --- a/src/yt_study/cli.py +++ b/src/yt_study/cli.py @@ -21,80 +21,63 @@ log_dir = Path.home() / ".yt-study" / "logs" try: log_dir.mkdir(parents=True, exist_ok=True) -except Exception: - # Fallback if home is not writable - log_dir = Path.cwd() / "logs" - log_dir.mkdir(parents=True, exist_ok=True) - -# Use timestamped log file for session isolation -timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -log_file = log_dir / f"yt-study-{timestamp}.log" - -root_logger = logging.getLogger() -root_logger.setLevel(logging.INFO) - -# Console Handler: Warning+, Clean output 
-console_handler = RichHandler(rich_tracebacks=False, show_time=False, show_path=False) -console_handler.setLevel(logging.WARNING) -root_logger.addHandler(console_handler) - -# File Handler: Debug+, Detailed format -try: - file_handler = logging.FileHandler(log_file, encoding="utf-8") - file_handler.setLevel(logging.DEBUG) - file_handler.setFormatter( - logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + log_file = log_dir / f"yt_study_{datetime.now():%Y%m%d}.log" + + # Configure root logger + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(log_file, encoding="utf-8"), + RichHandler( + rich_tracebacks=True, + show_path=False, + markup=True, + level=logging.WARNING, + ), + ], ) - root_logger.addHandler(file_handler) -except Exception: - pass +except OSError: + # Fall back to basic logging if directory creation fails + logging.basicConfig(level=logging.INFO) + app = typer.Typer( name="yt-study", - help=( - "🎓 Convert YouTube videos and playlists into comprehensive " - "study materials using AI." - ), - add_completion=True, + help="🎓 YouTube Study Material Generator • Transform videos into study notes.", + add_completion=False, rich_markup_mode="rich", ) console = Console() +# Module-level logger +logger = logging.getLogger("yt-study") + def check_config_exists() -> bool: - """Check if user configuration exists.""" + """Check if configuration file exists.""" config_path = Path.home() / ".yt-study" / "config.env" return config_path.exists() def ensure_setup() -> None: - """ - Ensure setup wizard has been run. - Triggers setup if config is missing. - """ + """Check if setup has been run, prompt if not.""" if not check_config_exists(): console.print( - "\n[yellow]⚠ No configuration found. Running setup wizard...[/yellow]\n" + "[yellow]⚠ No configuration found. 
Running setup wizard...[/yellow]\n" ) - try: - from .core.setup_wizard import run_setup_wizard + from .ui.wizard import run_setup_wizard - run_setup_wizard(force=False) - except ImportError as e: - console.print("[red]Critical: Could not import setup wizard.[/red]") - raise typer.Exit(code=1) from e + run_setup_wizard() -@app.command() -def process( +@app.command(name="process") +def process_command( url: Annotated[ str, typer.Argument( - help=( - "YouTube video or playlist URL, or path to a text file containing URLs." - ), - show_default=False, + help="YouTube video/playlist URL or path to batch file.", ), ], model: Annotated[ @@ -102,10 +85,7 @@ def process( typer.Option( "--model", "-m", - help=( - "LLM model (overrides config). Example: [green]gpt-4o[/green] " - "or [green]gemini/gemini-2.0-flash[/green]" - ), + help="LLM model to use. Overrides config.", ), ] = None, output: Annotated[ @@ -113,22 +93,15 @@ def process( typer.Option( "--output", "-o", - help="Output directory (overrides config).", - exists=False, - file_okay=False, - dir_okay=True, - resolve_path=True, + help="Output directory.", ), ] = None, - language: Annotated[ - list[str] | None, + lang: Annotated[ + str | None, typer.Option( - "--language", + "--lang", "-l", - help=( - "Preferred transcript languages " - "(e.g., [green]en[/green], [green]hi[/green])." - ), + help="Preferred transcript language codes (comma-separated).", ), ] = None, temperature: Annotated[ @@ -136,191 +109,151 @@ def process( typer.Option( "--temperature", "-t", - help=( - "LLM response temperature (overrides config). " - "Range: 0.0 to 1.0 (default = 0.7)" - ), - min=0.0, - max=1.0, + help="LLM temperature (0.0-1.0). Controls output randomness.", ), ] = None, max_tokens: Annotated[ int | None, typer.Option( "--max-tokens", - "-k", - help=( - "Maximum tokens for LLM responses (overrides config). " - "Adjust based on model limits. 
(None for model default)" - ), - min=1, + help="Maximum tokens for LLM response.", ), ] = None, ) -> None: - """ - Generate comprehensive study notes from YouTube videos or playlists. - - Supports: - \b - 1. Single Video URL - 2. Playlist URL - 3. Batch file (text file with one URL per line) - - \b - Examples: - [cyan]yt-study process "https://youtube.com/watch?v=VIDEO_ID"[/cyan] - [cyan]yt-study process "URL" -m gpt-4o[/cyan] - [cyan]yt-study process batch_urls.txt -o ./course-notes[/cyan] - """ - # Ensure configuration exists - ensure_setup() - - try: - # Lazy import for faster CLI startup - from .core.config import config - from .core.orchestrator import PipelineOrchestrator - - # Use config values as defaults, allow CLI overrides - selected_model = model or config.default_model - selected_output = output or config.default_output_dir - selected_languages = language or config.default_languages - selected_temperature = ( - temperature if temperature is not None else config.temperature - ) - selected_max_tokens = ( - max_tokens if max_tokens is not None else config.max_tokens - ) + """Process a YouTube video or playlist into study materials.""" + from rich.panel import Panel - # Create orchestrator - orchestrator = PipelineOrchestrator( - model=selected_model, - output_dir=selected_output, - languages=selected_languages, - temperature=selected_temperature, - max_tokens=selected_max_tokens, - ) - - async def run_processing() -> None: - """Determine if input is URL or file and run pipeline.""" - input_path = Path(url) - - # Check if input is an existing file (Batch Mode) - if input_path.exists() and input_path.is_file(): - # Removed redundant panel print here since dashboard handles UI - try: - # Robust encoding handling and line splitting - content = input_path.read_text(encoding="utf-8") - urls = [ - line.strip() - for line in content.splitlines() - if line.strip() and not line.strip().startswith("#") - ] - except Exception as e: - console.print( - f"[bold red]❌ 
Error reading batch file:[/bold red] {e}" - ) - return - - if not urls: - console.print("[yellow]⚠ Batch file is empty.[/yellow]") - return - - # Removed: console.print(f"[dim]Found {len(urls)} URLs[/dim]\n") - - for i, batch_url in enumerate(urls, 1): - # Keep this rule as it separates batch items distinctly - console.rule(f"[bold cyan]Batch Item {i}/{len(urls)}[/bold cyan]") - # Removed redundant URL print as dashboard shows title/status - try: - await orchestrator.run(batch_url) - except Exception as e: - console.print(f"[bold red]❌ Batch item failed:[/bold red] {e}") - else: - # Single URL Mode (Orchestrator handles Video vs Playlist detection) - await orchestrator.run(url) - - # Run pipeline - asyncio.run(run_processing()) - - except KeyboardInterrupt: - console.print("\n[yellow]⚠ Process interrupted by user[/yellow]") - raise typer.Exit(code=1) from None - except Exception as e: - # Import Panel locally - from rich.panel import Panel - - console.print( - Panel(f"[bold red]Fatal Error[/bold red]\n{str(e)}", border_style="red") - ) - logging.exception("Fatal error in CLI process") - raise typer.Exit(code=1) from e + from .core.config import config + ensure_setup() -@app.callback(invoke_without_command=True) -def main(ctx: typer.Context) -> None: - """ - [bold cyan]yt-study[/bold cyan]: AI-Powered Video Study Notes Generator. - - Convert YouTube content into structured Markdown notes. - """ - # Only show help if no subcommand is being invoked - if ctx.invoked_subcommand is None: - console.print(ctx.get_help()) - + # Resolve parameters from config defaults + selected_model = model or config.default_model + output_dir = output or config.default_output_dir + languages = lang.split(",") if lang else config.default_languages -@app.command() -def setup( - force: Annotated[ - bool, - typer.Option( - "--force", "-f", help="Force reconfiguration even if config exists." - ), - ] = False, -) -> None: - """ - Configure API keys and preferences interactively. 
- - Runs a wizard to generate the [bold]~/.yt-study/config.env[/bold] file. - """ + # Check if input is a batch file + input_path = Path(url) + if input_path.exists() and input_path.is_file(): + # Batch file mode + try: + urls = [ + line.strip() + for line in input_path.read_text(encoding="utf-8").strip().split("\n") + if line.strip() and not line.strip().startswith("#") + ] + if not urls: + console.print("[yellow]Batch file is empty.[/yellow]") + return + except OSError as e: + console.print(f"[red]Error reading batch file: {e}[/red]") + return + + console.print(f"[cyan]Processing {len(urls)} URLs from batch file...[/cyan]\n") + + from .ui.presenter import RichPipelinePresenter + + for batch_url in urls: + try: + presenter = RichPipelinePresenter( + model=selected_model, + output_dir=output_dir, + languages=languages, + temperature=temperature, + max_tokens=max_tokens, + ) + asyncio.run(presenter.run(batch_url)) + except KeyboardInterrupt: + console.print( + Panel( + "⚠ [yellow]Process interrupted by user[/yellow]", + border_style="yellow", + ) + ) + raise typer.Exit(code=1) from None + except Exception as e: + console.print(f"[red]Error processing {batch_url}: {e}[/red]") + logger.exception(f"Failed to process URL: {batch_url}") + else: + # Single URL mode + try: + from .ui.presenter import RichPipelinePresenter + + presenter = RichPipelinePresenter( + model=selected_model, + output_dir=output_dir, + languages=languages, + temperature=temperature, + max_tokens=max_tokens, + ) + asyncio.run(presenter.run(url)) + + except KeyboardInterrupt: + console.print( + Panel( + "⚠ [yellow]Process interrupted by user[/yellow]", + border_style="yellow", + ) + ) + raise typer.Exit(code=1) from None + except Exception as e: + console.print( + Panel( + f"[red bold]Fatal Error[/red bold]\n\n" + f"[white]{e}[/white]\n\n" + f"[dim]Check logs at {log_dir} for details.[/dim]", + title="💥 Error", + border_style="red", + ) + ) + logger.exception("Fatal CLI error") + raise 
typer.Exit(code=1) from None + + +@app.command(name="setup") +def setup_command() -> None: + """Run the interactive setup wizard.""" + from .ui.wizard import run_setup_wizard + + run_setup_wizard(force=True) + + +@app.command(name="version") +def version_command() -> None: + """Show yt-study version.""" try: - from .core.setup_wizard import run_setup_wizard + from . import __version__ - run_setup_wizard(force=force) - except ImportError as e: - console.print("[red]Setup wizard module missing.[/red]") - raise typer.Exit(code=1) from e + console.print(f"[cyan]yt-study[/cyan] version [bold]{__version__}[/bold]") + except ImportError: + console.print("[cyan]yt-study[/cyan] version [bold]0.1.0-dev[/bold]") -@app.command() -def config_path() -> None: +@app.command(name="config-path") +def config_path_command() -> None: """Show the path to the configuration file.""" - config_file = Path.home() / ".yt-study" / "config.env" + config_path = Path.home() / ".yt-study" / "config.env" - if config_file.exists(): - console.print(f"\n[cyan]Configuration file:[/cyan] {config_file}") - console.print("\n[dim]To edit: Open the file above in a text editor[/dim]") + if config_path.exists(): console.print( - "[dim]To reconfigure: Run[/dim] [cyan]yt-study setup --force[/cyan]\n" + f"[green]Configuration file:[/green] [cyan]{config_path}[/cyan]" ) else: - console.print("\n[yellow]No configuration found.[/yellow]") console.print( - "[dim]Run[/dim] [cyan]yt-study setup[/cyan] [dim]to create one.[/dim]\n" + "[yellow]No configuration found.[/yellow]\n" + "Run [cyan]yt-study setup[/cyan] to configure." ) -@app.command() -def version() -> None: - """Show version information.""" - try: - from . 
import __version__ - - ver = __version__ - except ImportError: - ver = "dev" - - console.print(f"[cyan]yt-study[/cyan] version [green]{ver}[/green]") +@app.callback(invoke_without_command=True) +def main( + ctx: typer.Context, +) -> None: + """🎓 YouTube Study Material Generator • Transform videos into study notes.""" + if ctx.invoked_subcommand is None: + console.print(ctx.get_help()) -if __name__ == "__main__": +def cli() -> None: + """Entry point for the CLI.""" app() diff --git a/src/yt_study/config.py b/src/yt_study/config.py deleted file mode 100644 index bc6a272..0000000 --- a/src/yt_study/config.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Backward compatibility - config moved to core.config.""" - -# Re-export everything from the new location -from yt_study.core.config import * # noqa: F401, F403 diff --git a/src/yt_study/core/__init__.py b/src/yt_study/core/__init__.py index ecd9adb..10edbf4 100644 --- a/src/yt_study/core/__init__.py +++ b/src/yt_study/core/__init__.py @@ -15,26 +15,23 @@ >>> result = await pipeline.run(["VIDEO_ID"], on_event=on_progress) """ -# Keep backward compatibility with old PipelineOrchestrator -from .orchestrator import PipelineOrchestrator -from .pipeline import ( - CorePipeline, +from .events import ( EventType, PipelineEvent, PipelineResult, +) +from .pipeline import ( + CorePipeline, run_pipeline, sanitize_filename, ) __all__ = [ - # New core API "CorePipeline", "EventType", "PipelineEvent", "PipelineResult", "run_pipeline", "sanitize_filename", - # Legacy (deprecated, for backward compatibility) - "PipelineOrchestrator", ] diff --git a/src/yt_study/core/events.py b/src/yt_study/core/events.py new file mode 100644 index 0000000..a610a85 --- /dev/null +++ b/src/yt_study/core/events.py @@ -0,0 +1,49 @@ +"""Event types and data classes for pipeline communication. + +These events allow the core pipeline to communicate progress +to any UI layer without depending on specific UI libraries. 
+""" + +from dataclasses import dataclass +from enum import Enum +from pathlib import Path + + +class EventType(Enum): + """Event types emitted by the pipeline.""" + + METADATA_START = "metadata_start" + METADATA_FETCHED = "metadata_fetched" + TRANSCRIPT_FETCHING = "transcript_fetching" + TRANSCRIPT_FETCHED = "transcript_fetched" + GENERATION_START = "generation_start" + CHAPTER_GENERATING = "chapter_generating" + GENERATION_COMPLETE = "generation_complete" + VIDEO_SUCCESS = "video_success" + VIDEO_FAILED = "video_failed" + PIPELINE_START = "pipeline_start" + PIPELINE_COMPLETE = "pipeline_complete" + + +@dataclass +class PipelineEvent: + """Event emitted during pipeline execution.""" + + event_type: EventType + video_id: str + title: str | None = None + chapter_number: int | None = None + total_chapters: int | None = None + error: str | None = None + output_path: Path | None = None + + +@dataclass +class PipelineResult: + """Result of pipeline execution.""" + + success_count: int + failure_count: int + total_count: int + video_ids: list[str] + errors: dict[str, str] # video_id -> error message diff --git a/src/yt_study/core/llm/generator.py b/src/yt_study/core/llm/generator.py index 0dfee2c..829e870 100644 --- a/src/yt_study/core/llm/generator.py +++ b/src/yt_study/core/llm/generator.py @@ -1,10 +1,9 @@ """Study material generator with chunking and combining logic.""" import logging +from collections.abc import Callable from litellm import token_counter -from rich.console import Console -from rich.progress import Progress, TaskID from ..config import config from ..prompts.chapter_notes import ( @@ -23,7 +22,6 @@ # Re-use system prompt for now CHAPTER_SYSTEM_PROMPT = SYSTEM_PROMPT -console = Console() logger = logging.getLogger(__name__) @@ -54,8 +52,6 @@ def __init__( def _count_tokens(self, text: str) -> int: """Count tokens in text using model-specific tokenizer.""" - # Note: token_counter might do network calls for some models or use - # local libraries 
(tiktoken). For efficiency, we assume it's fast. try: count = token_counter(model=self.provider.model, text=text) return int(count) if count is not None else len(text) // 4 @@ -110,8 +106,6 @@ def _chunk_transcript(self, transcript: str) -> list[str]: continue # Re-add delimiter for estimation (approximate) - # We assume '. ' was the delimiter for simplicity, logic holds - # for others mostly as we care about token count term = sentence + ". " term_tokens = self._count_tokens(term) @@ -124,7 +118,6 @@ def _chunk_transcript(self, transcript: str) -> list[str]: current_tokens = 0 # 2. Hard split the massive segment - # Estimate char limit based on token size (conservative 3 chars/token) char_limit = config.chunk_size * 3 for i in range(0, len(sentence), char_limit): sub_part = sentence[i : i + char_limit] @@ -141,7 +134,6 @@ def _chunk_transcript(self, transcript: str) -> list[str]: overlap_chunk: list[str] = [] overlap_tokens = 0 - # Take sentences from the end of current_chunk until overlap limit for prev_sent in reversed(current_chunk): prev_tokens = self._count_tokens(prev_sent) if overlap_tokens + prev_tokens <= config.chunk_overlap: @@ -153,7 +145,6 @@ def _chunk_transcript(self, transcript: str) -> list[str]: current_chunk = overlap_chunk + [sentence] current_tokens = self._count_tokens(" ".join(current_chunk)) else: - # Should be unreachable due to check above, but safe fallback current_chunk.append(sentence) current_tokens += term_tokens else: @@ -167,31 +158,11 @@ def _chunk_transcript(self, transcript: str) -> list[str]: logger.info(f"Created {len(chunks)} chunks") return chunks - def _update_status( - self, - progress: Progress | None, - task_id: TaskID | None, - video_title: str, - message: str, - ) -> None: - """Safe helper to update progress bar or log message.""" - if progress and task_id is not None: - short_title = ( - (video_title[:20] + "...") if len(video_title) > 20 else video_title - ) - # We assume the layout uses 'description' for the status 
text - progress.update( - task_id, description=f"[yellow]{short_title}[/yellow]: {message}" - ) - else: - logger.info(f"{video_title}: {message}") - async def generate_study_notes( self, transcript: str, video_title: str = "Video", - progress: Progress | None = None, - task_id: TaskID | None = None, + on_status: Callable[[str], None] | None = None, ) -> str: """ Generate study notes from transcript. @@ -199,17 +170,22 @@ async def generate_study_notes( Args: transcript: Full video transcript text. video_title: Video title for progress display. - progress: Optional existing progress bar instance. - task_id: Optional task ID for updating progress. + on_status: Optional callback for status updates. + Signature: (status_message: str) -> None Returns: Complete study notes in Markdown format. """ chunks = self._chunk_transcript(transcript) + def _emit(msg: str) -> None: + logger.info(f"{video_title}: {msg}") + if on_status: + on_status(msg) + # Single chunk - generate directly if len(chunks) == 1: - self._update_status(progress, task_id, video_title, "Generating notes...") + _emit("Generating notes...") notes = await self.provider.generate( system_prompt=SYSTEM_PROMPT, @@ -218,23 +194,15 @@ async def generate_study_notes( max_tokens=self.max_tokens, ) - if not progress: - logger.info(f"Generated notes for {video_title}") return notes # Multiple chunks - generate for each, then combine - self._update_status( - progress, - task_id, - video_title, - f"Generating notes for {len(chunks)} chunks...", - ) + _emit(f"Generating notes for {len(chunks)} chunks...") chunk_notes = [] for i, chunk in enumerate(chunks, 1): - msg = f"Chunk {i}/{len(chunks)} (Generating)" - self._update_status(progress, task_id, video_title, msg) + _emit(f"Chunk {i}/{len(chunks)} (Generating)") note = await self.provider.generate( system_prompt=SYSTEM_PROMPT, @@ -244,12 +212,7 @@ async def generate_study_notes( ) chunk_notes.append(note) - self._update_status( - progress, - task_id, - video_title, - 
f"Combining {len(chunk_notes)} chunk notes...", - ) + _emit(f"Combining {len(chunk_notes)} chunk notes...") final_notes = await self.provider.generate( system_prompt=SYSTEM_PROMPT, @@ -258,9 +221,6 @@ async def generate_study_notes( max_tokens=self.max_tokens, ) - if not progress: - logger.info(f"Completed notes for {video_title}") - return final_notes async def generate_single_chapter_notes( @@ -290,8 +250,7 @@ async def generate_chapter_based_notes( self, chapter_transcripts: dict[str, str], video_title: str = "Video", - progress: Progress | None = None, - task_id: TaskID | None = None, + on_status: Callable[[str], None] | None = None, ) -> str: """ Generate study notes using chapter-based approach. @@ -299,22 +258,18 @@ async def generate_chapter_based_notes( Args: chapter_transcripts: Dictionary mapping chapter titles to transcript text. video_title: Video title for display. - progress: Optional existing progress bar instance. - task_id: Optional task ID for updating progress. + on_status: Optional callback for status updates. Returns: Complete study notes organized by chapters. """ - # Imports are already at top-level or can be moved up, but let's - # fix the specific issue. Previously we did lazy import inside - # function which caused issues - - self._update_status( - progress, - task_id, - video_title, - f"Generating notes for {len(chapter_transcripts)} chapters...", - ) + + def _emit(msg: str) -> None: + logger.info(f"{video_title}: {msg}") + if on_status: + on_status(msg) + + _emit(f"Generating notes for {len(chapter_transcripts)} chapters...") chapter_notes = {} total_chapters = len(chapter_transcripts) @@ -322,13 +277,7 @@ async def generate_chapter_based_notes( for i, (chapter_title, chapter_text) in enumerate( chapter_transcripts.items(), 1 ): - msg = f"Chapter {i}/{total_chapters}: {chapter_title[:20]}..." - self._update_status(progress, task_id, video_title, msg) - - # If a chapter is huge, we might need recursive chunking here too. 
- # For now, we assume chapters are reasonably sized or the model - # can handle ~100k context. Future improvement: Check token - # count of chapter_text and recurse if needed. + _emit(f"Chapter {i}/{total_chapters}: {chapter_title[:20]}...") notes = await self.provider.generate( system_prompt=CHAPTER_SYSTEM_PROMPT, @@ -338,9 +287,7 @@ async def generate_chapter_based_notes( ) chapter_notes[chapter_title] = notes - self._update_status( - progress, task_id, video_title, "Combining chapter notes..." - ) + _emit("Combining chapter notes...") final_notes = await self.provider.generate( system_prompt=CHAPTER_SYSTEM_PROMPT, @@ -349,7 +296,4 @@ async def generate_chapter_based_notes( max_tokens=self.max_tokens, ) - if not progress: - logger.info(f"Completed chapter-based notes for {video_title}") - return final_notes diff --git a/src/yt_study/core/llm/providers.py b/src/yt_study/core/llm/providers.py index 625e45d..e17975a 100644 --- a/src/yt_study/core/llm/providers.py +++ b/src/yt_study/core/llm/providers.py @@ -5,12 +5,10 @@ from typing import Any from litellm import acompletion -from rich.console import Console from ..config import config -console = Console() logger = logging.getLogger(__name__) diff --git a/src/yt_study/core/orchestrator.py b/src/yt_study/core/orchestrator.py deleted file mode 100644 index c6c67bd..0000000 --- a/src/yt_study/core/orchestrator.py +++ /dev/null @@ -1,534 +0,0 @@ -"""Main pipeline orchestrator with concurrent processing.""" - -import asyncio -import logging -import re -from pathlib import Path - -from rich.console import Console -from rich.live import Live -from rich.panel import Panel -from rich.progress import Progress, TaskID -from rich.table import Table - -from ..ui.dashboard import PipelineDashboard -from .config import config -from .llm.generator import StudyMaterialGenerator -from .llm.providers import get_provider -from .youtube.metadata import ( - get_playlist_info, - get_video_chapters, - get_video_duration, - 
get_video_title, -) -from .youtube.parser import parse_youtube_url -from .youtube.playlist import extract_playlist_videos -from .youtube.transcript import ( - YouTubeIPBlockError, - fetch_transcript, - split_transcript_by_chapters, -) - - -console = Console() -logger = logging.getLogger(__name__) - - -def sanitize_filename(name: str) -> str: - """ - Sanitize a string to be used as a filename. - - Args: - name: Raw filename string. - - Returns: - Sanitized string safe for file systems. - """ - # Remove or replace invalid characters - name = re.sub(r'[<>:"/\\|?*]', "", name) - # Replace multiple spaces with single space - name = re.sub(r"\s+", " ", name) - # Trim and limit length - name = name.strip()[:100] - return name if name else "untitled" - - -class PipelineOrchestrator: - """ - Orchestrates the end-to-end pipeline for video processing. - - Manages concurrency, error handling, and UI updates. - """ - - def __init__( - self, - model: str = "gemini/gemini-2.0-flash", - output_dir: Path | None = None, - languages: list[str] | None = None, - temperature: float | None = None, - max_tokens: int | None = None, - ): - """ - Initialize orchestrator. - - Args: - model: LLM model string. - output_dir: Output directory path. - languages: Preferred transcript languages. - temperature: LLM temperature (defaults to config.temperature). - max_tokens: Max tokens (defaults to config.max_tokens). 
- """ - self.model = model - self.output_dir = output_dir or config.default_output_dir - self.languages = languages or config.default_languages - self.temperature = ( - temperature if temperature is not None else config.temperature - ) - self.max_tokens = max_tokens if max_tokens is not None else config.max_tokens - self.provider = get_provider(model) - self.generator = StudyMaterialGenerator( - self.provider, - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - self.semaphore = asyncio.Semaphore(config.max_concurrent_videos) - - def validate_provider(self) -> bool: - """ - Validate that the API key for the selected provider is set. - - Returns: - True if valid (or warning logged), False if critical missing config. - """ - key_name = config.get_api_key_name_for_model(self.model) - - if key_name: - import os - - if not os.environ.get(key_name): - console.print( - f"\n[red bold]✗ Missing API Key for {self.model}[/red bold]" - ) - console.print( - f"[yellow]Expected environment variable: {key_name}[/yellow]" - ) - console.print( - "[dim]Please check your .env file or run:[/dim] " - "[cyan]yt-study setup[/cyan]\n" - ) - return False - - return True - - async def process_video( - self, - video_id: str, - output_path: Path, - progress: Progress | None = None, - task_id: TaskID | None = None, - video_title: str | None = None, - is_playlist: bool = False, - ) -> bool: - """ - Process a single video: fetch transcript and generate study notes. - - Args: - video_id: YouTube Video ID. - output_path: Destination path for the MD file. - progress: Rich Progress instance. - task_id: Rich TaskID. - video_title: Pre-fetched title (optional). - is_playlist: Whether this is part of a playlist (affects UI logging). - - Returns: - True on success, False on failure. 
- """ - async with self.semaphore: - local_task_id = task_id - - # If standalone (not part of worker pool), create a specific - # bar if requested - if is_playlist and progress and task_id is None: - display_title = (video_title or video_id)[:30] - local_task_id = progress.add_task( - description=f"[cyan]⏳ {display_title}... (Waiting)[/cyan]", - total=None, - ) - - try: - # 1. Fetch Metadata - if not video_title: - # Run in thread to avoid blocking - video_title = await asyncio.to_thread(get_video_title, video_id) - - # Fetch duration and chapters concurrently - duration, chapters = await asyncio.gather( - asyncio.to_thread(get_video_duration, video_id), - asyncio.to_thread(get_video_chapters, video_id), - ) - - title_display = (video_title or video_id)[:40] - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=f"[cyan]📥 {title_display}... (Transcript)[/cyan]", - ) - - # 2. Fetch Transcript - transcript_obj = await fetch_transcript(video_id, self.languages) - - # 3. Determine Generation Strategy - # Use chapters if video is long (>1h) and chapters exist - use_chapters = duration > 3600 and len(chapters) > 0 and not is_playlist - - if use_chapters: - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[cyan]📖 {title_display}... 
(Chapters)[/cyan]" - ), - ) - # else block removed as redundant - - # Split transcript - chapter_transcripts = split_transcript_by_chapters( - transcript_obj, chapters - ) - - # Create folder for chapter notes - safe_title = sanitize_filename(video_title) - output_folder = self.output_dir / safe_title - output_folder.mkdir(parents=True, exist_ok=True) - - # Generate chapter notes - # Fix: Iterate here and call generator for each chapter - # to save individually - - for i, (chap_title, chap_text) in enumerate( - chapter_transcripts.items(), 1 - ): - status_msg = f"Chapter {i}/{len(chapter_transcripts)}" - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[cyan]🤖 {title_display}... ({status_msg})[/cyan]" - ), - ) - - notes = await self.generator.generate_single_chapter_notes( - chapter_title=chap_title, - chapter_text=chap_text, - ) - - # Save individual chapter - safe_chapter = sanitize_filename(chap_title) - chapter_file = output_folder / f"{i:02d}_{safe_chapter}.md" - chapter_file.write_text(notes, encoding="utf-8") - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=f"[green]✓ {title_display} (Done)[/green]", - completed=True, - ) - - return True - - else: - # Single file generation - transcript_text = transcript_obj.to_text() - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[cyan]🤖 {title_display}... 
(Generating)[/cyan]" - ), - ) - - notes = await self.generator.generate_study_notes( - transcript_text, - video_title=title_display, - progress=progress, - task_id=local_task_id, - ) - - output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(notes, encoding="utf-8") - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=f"[green]✓ {title_display} (Done)[/green]", - completed=True, - ) - - return True - - except Exception as e: - logger.error(f"Failed to process {video_id}: {e}") - - err_msg = str(e) - if isinstance(e, YouTubeIPBlockError) or ( - "blocking requests" in err_msg - ): - err_display = "[bold red]IP BLOCKED[/bold red]" - console.print( - Panel( - "[bold red]🚫 YouTube IP Block Detected[/bold red]\n\n" - "YouTube is limiting requests from your IP address.\n" - "[yellow]➤ Recommendation:[/yellow] Use a VPN or " - "wait ~1 hour.", - border_style="red", - ) - ) - else: - err_display = "(Failed)" - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[red]✗ {(video_title or video_id)[:20]}... " - f"{err_display}[/red]" - ), - visible=True, - ) - - return False - - async def _process_with_dashboard( - self, - video_ids: list[str], - playlist_name: str = "Queue", - is_single_video: bool = False, - ) -> int: - """Process a list of videos using the Advanced Dashboard UI.""" - from ..ui.dashboard import PipelineDashboard - - # Initialize Dashboard FIRST to capture all output - # Adjust concurrency display: if total_videos < max_concurrency, - # only show needed workers - actual_concurrency = min(len(video_ids), config.max_concurrent_videos) - - dashboard = PipelineDashboard( - total_videos=len(video_ids), - concurrency=actual_concurrency, - playlist_name=playlist_name, - model_name=self.model, - ) - - success_count = 0 - video_titles = {} - - # Run Live Display (inline, not full screen) - # We start it immediately to show "Fetching metadata..." 
state - with Live(dashboard, refresh_per_second=10, console=console, screen=False): - # --- Phase 1: Metadata Fetching --- - TITLE_FETCH_CONCURRENCY = 10 - if not is_single_video: - dashboard.update_overall_status( - f"[cyan]📋 Fetching metadata for {len(video_ids)} videos...[/cyan]" - ) - - title_semaphore = asyncio.Semaphore(TITLE_FETCH_CONCURRENCY) - - async def fetch_title_safe(vid: str) -> str: - async with title_semaphore: - try: - return await asyncio.to_thread(get_video_title, vid) - except Exception: - return vid - - # Fetch titles - titles = await asyncio.gather(*(fetch_title_safe(vid) for vid in video_ids)) - video_titles = dict(zip(video_ids, titles, strict=True)) - - # --- Phase 2: Processing --- - if not is_single_video: - dashboard.update_overall_status("[bold blue]Total Progress[/bold blue]") - - # Determine base output folder - if is_single_video: - base_folder = self.output_dir - else: - base_folder = self.output_dir / sanitize_filename(playlist_name) - base_folder.mkdir(parents=True, exist_ok=True) - - # Worker Queue Implementation - queue: asyncio.Queue[str] = asyncio.Queue() - for vid in video_ids: - queue.put_nowait(vid) - - async def worker(worker_idx: int, task_id: TaskID) -> None: - nonlocal success_count - while not queue.empty(): - try: - video_id = queue.get_nowait() - except asyncio.QueueEmpty: - break - - title = video_titles.get(video_id, video_id) - safe_title = sanitize_filename(title) - - if is_single_video: - video_folder = base_folder / safe_title - output_path = video_folder / f"{safe_title}.md" - else: - output_path = base_folder / f"{safe_title}.md" - - # Update status - dashboard.update_worker( - worker_idx, f"[yellow]{title[:30]}...[/yellow]" - ) - - try: - result = await self.process_video( - video_id, - output_path, - progress=dashboard.worker_progress, - task_id=task_id, - video_title=title, - is_playlist=not is_single_video, - ) - - if result: - success_count += 1 - dashboard.add_completion(title) - else: - 
dashboard.add_failure(title) - - except Exception as e: - logger.error(f"Worker {worker_idx} failed on {video_id}: {e}") - dashboard.update_worker(worker_idx, f"[red]Error: {e}[/red]") - dashboard.add_failure(title) - await asyncio.sleep(2) - finally: - queue.task_done() - - # Worker done - dashboard.update_worker(worker_idx, "[dim]Idle[/dim]") - - try: - workers = [ - asyncio.create_task(worker(i, dashboard.worker_tasks[i])) - for i in range(actual_concurrency) - ] - await asyncio.gather(*workers) - except Exception as e: - logger.error(f"Dashboard execution failed: {e}") - - # Print summary table after dashboard closes - self._print_summary(dashboard) - - return success_count - - def _print_summary(self, dashboard: "PipelineDashboard") -> None: - """Print a summary table of the run.""" - if not dashboard.recent_completions and not dashboard.recent_failures: - return - - summary_table = Table( - title="📊 Processing Summary", - border_style="cyan", - show_header=True, - header_style="bold magenta", - ) - summary_table.add_column("Status", justify="center") - summary_table.add_column("Video Title", style="dim") - - # Add failures first (more important) - if dashboard.recent_failures: - for fail in dashboard.recent_failures: - summary_table.add_row("[bold red]FAILED[/bold red]", fail) - - # Add successes - if dashboard.recent_completions: - for comp in dashboard.recent_completions: - summary_table.add_row("[green]SUCCESS[/green]", comp) - - console.print("\n") - console.print(summary_table) - console.print( - f"\n[bold]Total Completed:[/bold] " - f"{dashboard.overall_progress.tasks[0].completed}/" - f"{dashboard.overall_progress.tasks[0].total}" - ) - console.print("[dim]Check logs for detailed error reports.[/dim]\n") - - async def process_playlist( - self, playlist_id: str, playlist_name: str = "playlist" - ) -> int: - """Process playlist with concurrent dynamic progress bars.""" - video_ids = await extract_playlist_videos(playlist_id) - return await 
self._process_with_dashboard(video_ids, playlist_name) - - async def run(self, url: str) -> None: - """ - Run the pipeline for a given YouTube URL. - - Args: - url: YouTube video or playlist URL. - """ - # Validate Provider Credentials - if not self.validate_provider(): - return - - try: - # Parse URL - parsed = parse_youtube_url(url) - - if parsed.url_type == "video": - if not parsed.video_id: - console.print("[red]Error: Video ID could not be extracted[/red]") - return - - await self._process_with_dashboard( - [parsed.video_id], - playlist_name="Single Video", - is_single_video=True, - ) - - # Summary is already printed by _process_with_dashboard - - elif parsed.url_type == "playlist": - if not parsed.playlist_id: - console.print( - "[red]Error: Playlist ID could not be extracted[/red]" - ) - return - - # Fetch basic playlist info first - handled in dashboard now - # if needed or kept minimal. Actually, playlist title fetching - # is useful to show BEFORE starting but _process_with_dashboard - # fetches metadata anyway. - # However, to pass playlist_name to dashboard, we might want it. - # But waiting for title can be slow. - # Let's let the dashboard handle titles for videos. - # For playlist title, we can try fast fetch or default to ID. - - # Fetching playlist title here is blocking/slow if not careful. - # Let's just use ID as name initially or fetch it quickly. - # The original code did fetch it. - - # To reduce redundancy, we remove the print statement - # "Playlist: ..." 
- playlist_title, _ = await asyncio.to_thread( - get_playlist_info, parsed.playlist_id - ) - - # Removed redundant print: - # console.print(f"[cyan]📑 Playlist:[/cyan] {playlist_title}\n") - - await self.process_playlist(parsed.playlist_id, playlist_title) - - # Summary handled by dashboard - - except ValueError as e: - console.print(f"[red]Input Error: {e}[/red]") - except Exception as e: - console.print(f"[red]Unexpected Error: {e}[/red]") - logger.exception("Pipeline run failed") diff --git a/src/yt_study/core/pipeline.py b/src/yt_study/core/pipeline.py index 73dfcff..68ef63c 100644 --- a/src/yt_study/core/pipeline.py +++ b/src/yt_study/core/pipeline.py @@ -10,11 +10,10 @@ import logging import re from collections.abc import Callable -from dataclasses import dataclass -from enum import Enum from pathlib import Path from .config import config +from .events import EventType, PipelineEvent, PipelineResult from .llm.generator import StudyMaterialGenerator from .llm.providers import get_provider from .youtube.metadata import ( @@ -32,46 +31,6 @@ logger = logging.getLogger(__name__) -class EventType(Enum): - """Event types emitted by the pipeline.""" - - METADATA_START = "metadata_start" - METADATA_FETCHED = "metadata_fetched" - TRANSCRIPT_FETCHING = "transcript_fetching" - TRANSCRIPT_FETCHED = "transcript_fetched" - GENERATION_START = "generation_start" - CHAPTER_GENERATING = "chapter_generating" - GENERATION_COMPLETE = "generation_complete" - VIDEO_SUCCESS = "video_success" - VIDEO_FAILED = "video_failed" - PIPELINE_START = "pipeline_start" - PIPELINE_COMPLETE = "pipeline_complete" - - -@dataclass -class PipelineEvent: - """Event emitted during pipeline execution.""" - - event_type: EventType - video_id: str - title: str | None = None - chapter_number: int | None = None - total_chapters: int | None = None - error: str | None = None - output_path: Path | None = None - - -@dataclass -class PipelineResult: - """Result of pipeline execution.""" - - success_count: int - 
failure_count: int - total_count: int - video_ids: list[str] - errors: dict[str, str] # video_id -> error message - - def sanitize_filename(name: str) -> str: """ Sanitize a string to be used as a filename. @@ -85,6 +44,11 @@ def sanitize_filename(name: str) -> str: name = re.sub(r'[<>:"/\\|?*]', "", name) name = re.sub(r"\s+", " ", name) name = name.strip()[:100] + + # Prevent directory traversal + if name in {".", ".."}: + return "untitled" + return name if name else "untitled" @@ -209,7 +173,9 @@ async def _process_single_video( safe_title = sanitize_filename(title) output_folder = self.output_dir / safe_title - output_folder.mkdir(parents=True, exist_ok=True) + await asyncio.to_thread( + lambda: output_folder.mkdir(parents=True, exist_ok=True) + ) total_chapters = len(chapter_transcripts) @@ -231,7 +197,9 @@ async def _process_single_video( safe_chapter = sanitize_filename(chap_title) chapter_file = output_folder / f"{i:02d}_{safe_chapter}.md" - chapter_file.write_text(notes, encoding="utf-8") + await asyncio.to_thread( + chapter_file.write_text, notes, encoding="utf-8" + ) emit( EventType.GENERATION_COMPLETE, @@ -252,8 +220,12 @@ async def _process_single_video( video_title=title, ) - output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(notes, encoding="utf-8") + await asyncio.to_thread( + lambda: output_path.parent.mkdir(parents=True, exist_ok=True) + ) + await asyncio.to_thread( + output_path.write_text, notes, encoding="utf-8" + ) emit( EventType.GENERATION_COMPLETE, @@ -320,8 +292,6 @@ async def run( on_event: Callable[[PipelineEvent], None] | None = None, ) -> PipelineResult: """ - ✅ SINGLE ENTRY POINT FOR CLI AND OTHER FRONTENDS - Process a list of video IDs concurrently. 
Args: diff --git a/src/yt_study/core/youtube/metadata.py b/src/yt_study/core/youtube/metadata.py index 14f2bca..c0cd834 100644 --- a/src/yt_study/core/youtube/metadata.py +++ b/src/yt_study/core/youtube/metadata.py @@ -5,10 +5,8 @@ from typing import Any from pytubefix import Playlist, YouTube -from rich.console import Console -console = Console() logger = logging.getLogger(__name__) diff --git a/src/yt_study/core/youtube/playlist.py b/src/yt_study/core/youtube/playlist.py index 706867a..6f4b354 100644 --- a/src/yt_study/core/youtube/playlist.py +++ b/src/yt_study/core/youtube/playlist.py @@ -4,10 +4,8 @@ import logging from pytubefix import Playlist -from rich.console import Console -console = Console() logger = logging.getLogger(__name__) diff --git a/src/yt_study/core/youtube/transcript.py b/src/yt_study/core/youtube/transcript.py index 1f8a2ac..b825d69 100644 --- a/src/yt_study/core/youtube/transcript.py +++ b/src/yt_study/core/youtube/transcript.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from typing import Any -from rich.console import Console from youtube_transcript_api import YouTubeTranscriptApi from youtube_transcript_api._errors import ( IpBlocked, @@ -18,7 +17,6 @@ from .metadata import VideoChapter -console = Console() logger = logging.getLogger(__name__) diff --git a/src/yt_study/llm/__init__.py b/src/yt_study/llm/__init__.py deleted file mode 100644 index 258ecae..0000000 --- a/src/yt_study/llm/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Backward compatibility - llm package moved to core.llm.""" - -from yt_study.core.llm import * # noqa: F401, F403 diff --git a/src/yt_study/llm/generator.py b/src/yt_study/llm/generator.py deleted file mode 100644 index 37e9917..0000000 --- a/src/yt_study/llm/generator.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Backward compatibility - llm.generator moved to core.llm.generator.""" - -from yt_study.core.llm.generator import * # noqa: F401, F403 diff --git a/src/yt_study/llm/providers.py 
b/src/yt_study/llm/providers.py deleted file mode 100644 index 0e1eeb2..0000000 --- a/src/yt_study/llm/providers.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Backward compatibility - llm.providers moved to core.llm.providers.""" - -from yt_study.core.llm.providers import * # noqa: F401, F403 diff --git a/src/yt_study/pipeline/__init__.py b/src/yt_study/pipeline/__init__.py deleted file mode 100644 index b20f6e9..0000000 --- a/src/yt_study/pipeline/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Backward compatibility - pipeline package moved to core.""" - -from yt_study.core.orchestrator import PipelineOrchestrator, sanitize_filename - - -__all__ = ["PipelineOrchestrator", "sanitize_filename"] diff --git a/src/yt_study/pipeline/orchestrator.py b/src/yt_study/pipeline/orchestrator.py deleted file mode 100644 index 7128c6b..0000000 --- a/src/yt_study/pipeline/orchestrator.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Backward compatibility - pipeline.orchestrator moved to core.orchestrator.""" - -# Re-export everything from the new location -from yt_study.core.orchestrator import * # noqa: F401, F403 diff --git a/src/yt_study/prompts/__init__.py b/src/yt_study/prompts/__init__.py deleted file mode 100644 index 5bc08d8..0000000 --- a/src/yt_study/prompts/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Prompt templates for study material generation.""" diff --git a/src/yt_study/prompts/chapter_notes.py b/src/yt_study/prompts/chapter_notes.py deleted file mode 100644 index 2e88328..0000000 --- a/src/yt_study/prompts/chapter_notes.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Prompt templates for chapter-based study material generation.""" - -# Prompt for generating notes from a single chapter -CHAPTER_GENERATION_PROMPT = """ -Create an in-depth, detailed study guide for this specific chapter: - -Chapter Title: {chapter_title} - -Transcript: -{transcript_chunk} - -Requirements: -1. **Deep Dive**: Provide a thorough, granular explanation of the chapter's topic. -2. 
**Comprehensive**: Include every nuance, sub-point, and detail mentioned. -3. **Clarify Concepts**: Explain "why" and "how" for every concept, not just "what". -4. **Examples**: Preserve all examples and use them to illustrate technical points. -5. **Structure**: Use deeply nested headers (###, ####) to break down complex ideas. -6. Pure Markdown format. -7. English language. -8. **DO NOT include any opening or closing conversational text.** -9. **Start directly with the first header (e.g., # Chapter Title)**""" - - -# Prompt for combining chapter notes -COMBINE_CHAPTER_NOTES_PROMPT = """ -You have generated study notes for different chapters of the same video. -Combine these chapter notes into a single, well-organized study document. - -Video chapters and notes: -{chapter_notes} - -Requirements: -1. Keep chapter structure with clear headers (## Chapter Title) -2. Ensure logical flow between chapters -3. Remove redundancies while preserving all unique information -4. Add a brief introduction summarizing what the video covers -5. Maintain all important details from every chapter -6. Use proper Markdown hierarchy (##, ###, etc.) -7. Do NOT add a table of contents -8. 
Create a cohesive document that's easy to navigate and review""" - - -def get_chapter_prompt(chapter_title: str, transcript_chunk: str) -> str: - """Generate prompt for a chapter.""" - return CHAPTER_GENERATION_PROMPT.format( - chapter_title=chapter_title, transcript_chunk=transcript_chunk - ) - - -def get_combine_chapters_prompt(chapter_notes: dict[str, str]) -> str: - """Generate prompt for combining chapter notes.""" - combined = "\n\n".join( - [f"## {title}\n\n{notes}" for title, notes in chapter_notes.items()] - ) - return COMBINE_CHAPTER_NOTES_PROMPT.format(chapter_notes=combined) diff --git a/src/yt_study/prompts/study_notes.py b/src/yt_study/prompts/study_notes.py deleted file mode 100644 index 02b3f6d..0000000 --- a/src/yt_study/prompts/study_notes.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Prompt templates for study material generation and chunk combining.""" - -# System prompt for generating study notes from transcript chunks -SYSTEM_PROMPT = """ -You are an expert academic tutor and technical writer dedicated to creating -the most comprehensive study materials possible. - -Your goal is to transform video transcripts into deep, detailed, and highly -structured study notes. -You prioritize: -- **Depth**: Go beyond surface-level summaries. Explain *why* and *how*, not - just *what*. -- **Comprehensive Coverage**: Capture every single concept, detail, nuance, - and example mentioned. -- **Clarity**: Use clear, academic yet accessible language. Break down complex topics. -- **Structure**: Use logical hierarchy (headers, subheaders) to organize - information effectively. - -Always generate output in clean Markdown format.""" - -# User prompt for individual transcript chunks -CHUNK_GENERATION_PROMPT = """ -Create extremely detailed and in-depth study notes from this transcript -segment: - -{transcript_chunk} - -Requirements: -1. **Comprehensive Coverage**: Cover EVERY concept, definition, theory, and - significant detail mentioned. Do not summarize; expand. -2. 
**In-Depth Explanation**: Explain complex ideas thoroughly. If a process - is described, break it down step-by-step. -3. **Capture Examples & Code**: Include ALL examples, case studies, and - especially **CODE BLOCKS/SQL** provided in the transcript. -4. **Technical Precision**: Use actual SQL syntax for table definitions - (e.g., `CREATE TABLE`), not just descriptions. -5. **Logical Structure**: Use deep hierarchy (##, ###, ####) to organize - related concepts. -6. **Key Terminology**: Highlight and define technical terms or important vocabulary. -7. **Pure Markdown**: No HTML, no table of contents. -8. **Clean Start**: Start directly with the content headers, no conversational filler. -9. **Language**: English.""" - -# Prompt for combining multiple chunk notes into final document -COMBINE_CHUNKS_PROMPT = """ -You have generated study notes for multiple segments of the same video. Now -combine these segments into a single, coherent study document. - -Segment notes: -{chunk_notes} - -Requirements: -1. Merge all segments into a unified, flowing document -2. **Preserve ALL Content**: Do NOT summarize or condense. Retain all - explanations, examples, code blocks, and details. -3. **Preserve Code & Syntax**: Use valid `CREATE TABLE` SQL and other - specific syntax exactly as presented. -4. **Seamless Merge**: Connect segments smoothly, but do not delete content for brevity. -5. **Detailed & Comprehensive**: The final document must be as detailed as - the input segments combined. -6. Maintain consistent formatting and structure (##, ###). -7. Do NOT add a table of contents. -8. **Example clean output:** "# Title\\n\\n## Section 1..." - -Create study notes that are comprehensive, well-organized, and easy to review.""" - -# Prompt for single-pass generation (small transcripts) -SINGLE_PASS_PROMPT = """ -Create an extensive and in-depth study guide from this complete video -transcript: - -{transcript} - -Requirements: -1. 
**Exhaustive Coverage**: Cover every single topic discussed. Do not leave - out details. -2. **Deep Understanding**: Explain concepts clearly and thoroughly, as if - teaching a student. -3. **Structured Learning**: Use a clear, logical hierarchy (##, ###, ####) - to organize topics. -4. **Examples & Context**: Retain all illustrative examples and context - provided in the video. -5. **No Summarization**: Do not summarize brief points; expand them for full - understanding. -6. Pure Markdown format (no HTML, no table of contents). -7. English language output. -8. **Clean Start**: Start directly with the first header (e.g. # Video - Title), no filler.""" - - -def get_chunk_prompt(transcript_chunk: str) -> str: - """Generate prompt for a transcript chunk.""" - return CHUNK_GENERATION_PROMPT.format(transcript_chunk=transcript_chunk) - - -def get_combine_prompt(chunk_notes: list[str]) -> str: - """Generate prompt for combining chunk notes.""" - combined = "\n\n---\n\n".join( - [f"## Segment {i + 1}\n\n{note}" for i, note in enumerate(chunk_notes)] - ) - return COMBINE_CHUNKS_PROMPT.format(chunk_notes=combined) - - -def get_single_pass_prompt(transcript: str) -> str: - """Generate prompt for single-pass generation.""" - return SINGLE_PASS_PROMPT.format(transcript=transcript) diff --git a/src/yt_study/setup_wizard.py b/src/yt_study/setup_wizard.py deleted file mode 100644 index 823d111..0000000 --- a/src/yt_study/setup_wizard.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Backward compatibility - setup_wizard moved to core.setup_wizard.""" - -# Re-export everything from the new location -from yt_study.core.setup_wizard import * # noqa: F401, F403 diff --git a/src/yt_study/ui/presenter.py b/src/yt_study/ui/presenter.py new file mode 100644 index 0000000..ad96b64 --- /dev/null +++ b/src/yt_study/ui/presenter.py @@ -0,0 +1,287 @@ +""" +Rich TUI presenter for the pipeline. + +This module bridges CorePipeline (pure logic) with the Rich Dashboard (UI). 
+It handles URL parsing, playlist detection, and maps pipeline events to +dashboard updates. +""" + +import asyncio +import logging + +from rich.console import Console +from rich.live import Live +from rich.panel import Panel +from rich.table import Table + +from ..core.config import config +from ..core.events import EventType, PipelineEvent +from ..core.pipeline import CorePipeline, sanitize_filename +from ..core.youtube.metadata import get_playlist_info, get_video_title +from ..core.youtube.parser import parse_youtube_url +from ..core.youtube.playlist import extract_playlist_videos +from .dashboard import PipelineDashboard + + +console = Console() +logger = logging.getLogger(__name__) + + +class RichPipelinePresenter: + """ + Adapts CorePipeline events to a Rich Dashboard. + + Acts as the glue between business logic (CorePipeline) and + the terminal UI (PipelineDashboard + Rich Live). + """ + + def __init__( + self, + model: str = "gemini/gemini-2.0-flash", + output_dir=None, + languages=None, + temperature=None, + max_tokens=None, + ): + """ + Initialize presenter with a CorePipeline. + + Args: + model: LLM model string. + output_dir: Output directory path. + languages: Preferred transcript languages. + temperature: LLM temperature. + max_tokens: Max tokens for generation. + """ + self.model = model + self.pipeline = CorePipeline( + model=model, + output_dir=output_dir, + languages=languages, + temperature=temperature, + max_tokens=max_tokens, + ) + self.dashboard: PipelineDashboard | None = None + self.video_titles: dict[str, str] = {} + + def validate_provider(self) -> bool: + """ + Validate that the API key for the selected provider is set. + + Returns: + True if valid, False if missing. 
+ """ + key_name = config.get_api_key_name_for_model(self.model) + + if key_name: + import os + + if not os.environ.get(key_name): + console.print( + f"\n[red bold]✗ Missing API Key for {self.model}[/red bold]" + ) + console.print( + f"[yellow]Expected environment variable: {key_name}[/yellow]" + ) + console.print( + "[dim]Please check your .env file or run:[/dim] " + "[cyan]yt-study setup[/cyan]\n" + ) + return False + + return True + + def _handle_event(self, event: PipelineEvent) -> None: + """Map CorePipeline events to Dashboard updates.""" + if not self.dashboard: + return + + if event.event_type == EventType.VIDEO_SUCCESS: + title = event.title or event.video_id + self.dashboard.add_completion(title) + + elif event.event_type == EventType.VIDEO_FAILED: + title = event.title or event.video_id + self.dashboard.add_failure(title) + + # Special handling for IP blocks + if event.error and "IP blocked" in event.error: + console.print( + Panel( + "[bold red]🚫 YouTube IP Block Detected[/bold red]\n\n" + "YouTube is limiting requests from your IP address.\n" + "[yellow]➤ Recommendation:[/yellow] Use a VPN or " + "wait ~1 hour.", + border_style="red", + ) + ) + + async def _process_with_dashboard( + self, + video_ids: list[str], + playlist_name: str = "Queue", + is_single_video: bool = False, + ) -> int: + """Process a list of videos using the Dashboard UI.""" + actual_concurrency = min(len(video_ids), config.max_concurrent_videos) + + self.dashboard = PipelineDashboard( + total_videos=len(video_ids), + concurrency=actual_concurrency, + playlist_name=playlist_name, + model_name=self.model, + ) + + # --- Phase 1: Metadata Fetching --- + with Live(self.dashboard, refresh_per_second=10, console=console, screen=False): + TITLE_FETCH_CONCURRENCY = 10 + if not is_single_video: + self.dashboard.update_overall_status( + f"[cyan]📋 Fetching metadata for {len(video_ids)} videos...[/cyan]" + ) + + title_semaphore = asyncio.Semaphore(TITLE_FETCH_CONCURRENCY) + + async def 
fetch_title_safe(vid: str) -> str: + async with title_semaphore: + try: + return await asyncio.to_thread(get_video_title, vid) + except Exception: + return vid + + titles = await asyncio.gather( + *(fetch_title_safe(vid) for vid in video_ids) + ) + self.video_titles = dict(zip(video_ids, titles, strict=True)) + + # --- Phase 2: Processing --- + if not is_single_video: + self.dashboard.update_overall_status( + "[bold blue]Total Progress[/bold blue]" + ) + + # Determine base output folder + if is_single_video: + base_folder = self.pipeline.output_dir + else: + base_folder = self.pipeline.output_dir / sanitize_filename( + playlist_name + ) + base_folder.mkdir(parents=True, exist_ok=True) + + # Override pipeline output_dir for this run + original_output = self.pipeline.output_dir + + if is_single_video: + # For single video, create a subfolder with the video title + for vid in video_ids: + title = self.video_titles.get(vid, vid) + safe_title = sanitize_filename(title) + video_folder = base_folder / safe_title + self.pipeline.output_dir = video_folder + else: + self.pipeline.output_dir = base_folder + + # Update worker display + for i, vid in enumerate(video_ids[:actual_concurrency]): + title = self.video_titles.get(vid, vid) + self.dashboard.update_worker( + i, f"[yellow]{title[:30]}...[/yellow]" + ) + + result = await self.pipeline.run(video_ids, on_event=self._handle_event) + + # Restore original output dir + self.pipeline.output_dir = original_output + + # Print summary table after dashboard closes + self._print_summary() + + return result.success_count + + def _print_summary(self) -> None: + """Print a summary table of the run.""" + if not self.dashboard: + return + + if ( + not self.dashboard.recent_completions + and not self.dashboard.recent_failures + ): + return + + summary_table = Table( + title="📊 Processing Summary", + border_style="cyan", + show_header=True, + header_style="bold magenta", + ) + summary_table.add_column("Status", justify="center") + 
summary_table.add_column("Video Title", style="dim") + + # Add failures first (more important) + if self.dashboard.recent_failures: + for fail in self.dashboard.recent_failures: + summary_table.add_row("[bold red]FAILED[/bold red]", fail) + + # Add successes + if self.dashboard.recent_completions: + for comp in self.dashboard.recent_completions: + summary_table.add_row("[green]SUCCESS[/green]", comp) + + console.print("\n") + console.print(summary_table) + console.print( + f"\n[bold]Total Completed:[/bold] " + f"{self.dashboard.overall_progress.tasks[0].completed}/" + f"{self.dashboard.overall_progress.tasks[0].total}" + ) + console.print("[dim]Check logs for detailed error reports.[/dim]\n") + + async def run(self, url: str) -> None: + """ + Run the pipeline for a given YouTube URL. + + Args: + url: YouTube video or playlist URL. + """ + # Validate Provider Credentials + if not self.validate_provider(): + return + + try: + # Parse URL + parsed = parse_youtube_url(url) + + if parsed.url_type == "video": + if not parsed.video_id: + console.print( + "[red]Error: Video ID could not be extracted[/red]" + ) + return + + await self._process_with_dashboard( + [parsed.video_id], + playlist_name="Single Video", + is_single_video=True, + ) + + elif parsed.url_type == "playlist": + if not parsed.playlist_id: + console.print( + "[red]Error: Playlist ID could not be extracted[/red]" + ) + return + + playlist_title, _ = await asyncio.to_thread( + get_playlist_info, parsed.playlist_id + ) + + video_ids = await extract_playlist_videos(parsed.playlist_id) + await self._process_with_dashboard(video_ids, playlist_title) + + except ValueError as e: + console.print(f"[red]Input Error: {e}[/red]") + except Exception as e: + console.print(f"[red]Unexpected Error: {e}[/red]") + logger.exception("Pipeline run failed") diff --git a/src/yt_study/core/setup_wizard.py b/src/yt_study/ui/wizard.py similarity index 100% rename from src/yt_study/core/setup_wizard.py rename to 
src/yt_study/ui/wizard.py diff --git a/src/yt_study/youtube/__init__.py b/src/yt_study/youtube/__init__.py deleted file mode 100644 index 9d869cb..0000000 --- a/src/yt_study/youtube/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Backward compatibility - youtube package moved to core.youtube.""" - -from yt_study.core.youtube import * # noqa: F401, F403 diff --git a/src/yt_study/youtube/metadata.py b/src/yt_study/youtube/metadata.py deleted file mode 100644 index e5f1206..0000000 --- a/src/yt_study/youtube/metadata.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Backward compatibility - youtube.metadata moved to core.youtube.metadata.""" - -from yt_study.core.youtube.metadata import * # noqa: F401, F403 diff --git a/src/yt_study/youtube/parser.py b/src/yt_study/youtube/parser.py deleted file mode 100644 index 5a572b0..0000000 --- a/src/yt_study/youtube/parser.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Backward compatibility - youtube.parser moved to core.youtube.parser.""" - -from yt_study.core.youtube.parser import * # noqa: F401, F403 diff --git a/src/yt_study/youtube/playlist.py b/src/yt_study/youtube/playlist.py deleted file mode 100644 index 9a09b31..0000000 --- a/src/yt_study/youtube/playlist.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Backward compatibility - youtube.playlist moved to core.youtube.playlist.""" - -from yt_study.core.youtube.playlist import * # noqa: F401, F403 diff --git a/src/yt_study/youtube/transcript.py b/src/yt_study/youtube/transcript.py deleted file mode 100644 index 392156b..0000000 --- a/src/yt_study/youtube/transcript.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Backward compatibility - youtube.transcript moved to core.youtube.transcript.""" - -from yt_study.core.youtube.transcript import * # noqa: F401, F403 diff --git a/tests/test_cli.py b/tests/test_cli.py index c9c04c6..286863e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,9 +12,9 @@ @pytest.fixture -def mock_orchestrator(): # noqa: ARG001 - # Patch where PipelineOrchestrator is defined - with 
patch("yt_study.core.orchestrator.PipelineOrchestrator") as mock: +def mock_presenter(): # noqa: ARG001 + # Patch where RichPipelinePresenter is used in CLI + with patch("yt_study.ui.presenter.RichPipelinePresenter") as mock: instance = mock.return_value instance.run = AsyncMock() yield mock @@ -36,12 +36,6 @@ def test_version(): def test_version_import_error(): """Test version command handles missing __version__ gracefully.""" with patch.dict("sys.modules", {"yt_study": None}): - # Mocking import error for specific attribute is tricky with sys.modules - # simpler to patch the import statement inside cli.py if possible, - # or just assume the fallback logic works if __version__ is missing. - # Let's try patching builtins.__import__ specifically for that - # module? Too complex. - # Just manually call the function? No, tested via runner. pass @@ -61,15 +55,15 @@ def test_config_path_missing(): assert "No configuration found" in result.stdout -def test_process_url_success(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_url_success(mock_config_exists, mock_presenter): # noqa: ARG001 """Test processing a simple URL.""" result = runner.invoke(app, ["process", "https://youtube.com/watch?v=123"]) assert result.exit_code == 0 - mock_orchestrator.return_value.run.assert_awaited() + mock_presenter.return_value.run.assert_awaited() -def test_process_batch_file(mock_config_exists, mock_orchestrator, tmp_path): # noqa: ARG001 +def test_process_batch_file(mock_config_exists, mock_presenter, tmp_path): # noqa: ARG001 """Test processing a batch file.""" batch_file = tmp_path / "urls.txt" batch_file.write_text("https://yt.com/v1\nhttps://yt.com/v2") @@ -77,10 +71,10 @@ def test_process_batch_file(mock_config_exists, mock_orchestrator, tmp_path): # result = runner.invoke(app, ["process", str(batch_file)]) assert result.exit_code == 0 - assert mock_orchestrator.return_value.run.await_count == 2 + assert mock_presenter.return_value.run.await_count == 2 -def 
test_process_batch_file_empty(mock_config_exists, mock_orchestrator, tmp_path): # noqa: ARG001 +def test_process_batch_file_empty(mock_config_exists, mock_presenter, tmp_path): # noqa: ARG001 """Test processing an empty batch file.""" batch_file = tmp_path / "empty.txt" batch_file.write_text("") @@ -89,10 +83,10 @@ def test_process_batch_file_empty(mock_config_exists, mock_orchestrator, tmp_pat assert result.exit_code == 0 assert "Batch file is empty" in result.stdout - mock_orchestrator.return_value.run.assert_not_awaited() + mock_presenter.return_value.run.assert_not_awaited() -def test_process_batch_file_error(mock_config_exists, mock_orchestrator, tmp_path): # noqa: ARG001 +def test_process_batch_file_error(mock_config_exists, mock_presenter, tmp_path): # noqa: ARG001 """Test error reading batch file.""" batch_file = tmp_path / "restricted.txt" batch_file.touch() @@ -101,12 +95,7 @@ def test_process_batch_file_error(mock_config_exists, mock_orchestrator, tmp_pat with patch("pathlib.Path.read_text", side_effect=OSError("Access denied")): result = runner.invoke(app, ["process", str(batch_file)]) - assert ( - result.exit_code == 0 - ) # It returns early, exit code 0 usually unless exception propagates - # Wait, cli.py does return, so exit code 0 is correct for Typer - # unless we raise Exit. 
- # Checks stdout + assert result.exit_code == 0 assert "Error reading batch file" in result.stdout @@ -114,28 +103,25 @@ def test_process_missing_config(): """Test that missing config triggers setup check/error.""" with ( patch("yt_study.cli.check_config_exists", return_value=False), - patch("yt_study.core.setup_wizard.run_setup_wizard") as mock_setup, + patch("yt_study.ui.wizard.run_setup_wizard") as mock_setup, ): runner.invoke(app, ["process", "url"]) mock_setup.assert_called_once() -def test_process_keyboard_interrupt(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_keyboard_interrupt(mock_config_exists, mock_presenter): # noqa: ARG001 """Test handling of KeyboardInterrupt.""" - mock_orchestrator.return_value.run.side_effect = KeyboardInterrupt() + mock_presenter.return_value.run.side_effect = KeyboardInterrupt() result = runner.invoke(app, ["process", "url"]) assert result.exit_code == 1 - # Check for Rich Panel content format - # Rich markup uses symbols like ⚠ which might be encoded differently - # Let's match partial string content "Process interrupted" assert "Process interrupted by user" in result.stdout -def test_process_general_exception(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_general_exception(mock_config_exists, mock_presenter): # noqa: ARG001 """Test handling of general exceptions.""" - mock_orchestrator.return_value.run.side_effect = Exception("Boom") + mock_presenter.return_value.run.side_effect = Exception("Boom") result = runner.invoke(app, ["process", "url"]) @@ -146,7 +132,7 @@ def test_process_general_exception(mock_config_exists, mock_orchestrator): # no def test_setup_command(): """Test setup command triggers wizard.""" - with patch("yt_study.core.setup_wizard.run_setup_wizard") as mock_wizard: + with patch("yt_study.ui.wizard.run_setup_wizard") as mock_wizard: result = runner.invoke(app, ["setup"]) assert result.exit_code == 0 mock_wizard.assert_called_once() @@ -154,10 +140,7 @@ def 
test_setup_command(): def test_setup_import_error(): """Test setup command handling missing wizard module.""" - # Simulate ImportError when importing setup_wizard - with patch.dict("sys.modules", {"yt_study.core.setup_wizard": None}): - # This approach is tricky because we are inside the test process. - # Better to patch the specific import or function call if lazy. + with patch.dict("sys.modules", {"yt_study.ui.wizard": None}): pass @@ -165,9 +148,8 @@ def test_ensure_setup_import_error(): """Test ensure_setup handles missing wizard.""" with ( patch("yt_study.cli.check_config_exists", return_value=False), - patch.dict("sys.modules", {"yt_study.core.setup_wizard": None}), + patch.dict("sys.modules", {"yt_study.ui.wizard": None}), ): - # This won't work easily as the module is likely already imported. pass @@ -178,7 +160,7 @@ def test_callback_help(): assert "Usage" in result.stdout -def test_process_with_temperature_flag(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_with_temperature_flag(mock_config_exists, mock_presenter): # noqa: ARG001 """Test processing with custom temperature parameter.""" result = runner.invoke( app, @@ -186,12 +168,12 @@ def test_process_with_temperature_flag(mock_config_exists, mock_orchestrator): ) assert result.exit_code == 0 - call_kwargs = mock_orchestrator.call_args[1] + call_kwargs = mock_presenter.call_args[1] assert call_kwargs["temperature"] == 0.5 - mock_orchestrator.return_value.run.assert_awaited() + mock_presenter.return_value.run.assert_awaited() -def test_process_with_max_tokens_flag(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_with_max_tokens_flag(mock_config_exists, mock_presenter): # noqa: ARG001 """Test processing with custom max_tokens parameter.""" result = runner.invoke( app, @@ -199,12 +181,12 @@ def test_process_with_max_tokens_flag(mock_config_exists, mock_orchestrator): # ) assert result.exit_code == 0 - call_kwargs = mock_orchestrator.call_args[1] + 
call_kwargs = mock_presenter.call_args[1] assert call_kwargs["max_tokens"] == 2000 - mock_orchestrator.return_value.run.assert_awaited() + mock_presenter.return_value.run.assert_awaited() -def test_process_with_temperature_and_max_tokens(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_with_temperature_and_max_tokens(mock_config_exists, mock_presenter): # noqa: ARG001 """Test processing with both temperature and max_tokens parameters.""" result = runner.invoke( app, @@ -219,7 +201,7 @@ def test_process_with_temperature_and_max_tokens(mock_config_exists, mock_orches ) assert result.exit_code == 0 - call_kwargs = mock_orchestrator.call_args[1] + call_kwargs = mock_presenter.call_args[1] assert call_kwargs["temperature"] == 0.8 assert call_kwargs["max_tokens"] == 3000 - mock_orchestrator.return_value.run.assert_awaited() + mock_presenter.return_value.run.assert_awaited() diff --git a/tests/test_core_pipeline.py b/tests/test_core_pipeline.py new file mode 100644 index 0000000..51bdf23 --- /dev/null +++ b/tests/test_core_pipeline.py @@ -0,0 +1,214 @@ +"""Tests for the core pipeline.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from yt_study.core.events import EventType, PipelineResult +from yt_study.core.pipeline import CorePipeline, sanitize_filename + + +class TestSanitizeFilename: + """Test filename sanitization.""" + + def test_basic_name(self): + assert sanitize_filename("Hello World") == "Hello World" + + def test_special_characters(self): + assert sanitize_filename('foo/bar:baz"qux') == "foobarbazqux" + + def test_whitespace_normalization(self): + assert sanitize_filename(" spaces ") == "spaces" + + def test_length_limit(self): + assert len(sanitize_filename("a" * 200)) == 100 + + def test_empty_returns_untitled(self): + assert sanitize_filename("") == "untitled" + + def test_dot_traversal(self): + """Prevent directory traversal with '.' 
title.""" + assert sanitize_filename(".") == "untitled" + + def test_dotdot_traversal(self): + """Prevent directory traversal with '..' title.""" + assert sanitize_filename("..") == "untitled" + + def test_only_special_chars(self): + """All chars removed → returns untitled.""" + assert sanitize_filename('<>:"/\\|?*') == "untitled" + + +class TestCorePipelineInit: + """Test pipeline initialization.""" + + def test_defaults(self, temp_output_dir, mock_llm_provider): + with patch( + "yt_study.core.pipeline.get_provider", + return_value=mock_llm_provider, + ): + pipeline = CorePipeline( + model="mock-model", output_dir=temp_output_dir + ) + assert pipeline.model == "mock-model" + assert pipeline.output_dir == temp_output_dir + + def test_custom_params(self, temp_output_dir, mock_llm_provider): + with patch( + "yt_study.core.pipeline.get_provider", + return_value=mock_llm_provider, + ): + pipeline = CorePipeline( + model="mock-model", + output_dir=temp_output_dir, + temperature=0.5, + max_tokens=1000, + ) + assert pipeline.temperature == 0.5 + assert pipeline.max_tokens == 1000 + + +class TestCorePipelineRun: + """Test pipeline execution.""" + + @pytest.fixture + def pipeline(self, temp_output_dir, mock_llm_provider): + with patch( + "yt_study.core.pipeline.get_provider", + return_value=mock_llm_provider, + ): + p = CorePipeline(model="mock-model", output_dir=temp_output_dir) + p.generator = MagicMock() + p.generator.generate_study_notes = AsyncMock(return_value="# Notes") + p.generator.generate_single_chapter_notes = AsyncMock( + return_value="# Chapter Notes" + ) + return p + + @pytest.mark.asyncio + async def test_empty_list(self, pipeline): + """Empty video list returns zero counts.""" + result = await pipeline.run([]) + assert result.success_count == 0 + assert result.failure_count == 0 + assert result.total_count == 0 + + @pytest.mark.asyncio + async def test_success(self, pipeline): + """Successful processing emits VIDEO_SUCCESS and writes file.""" + events = [] 
+ + with ( + patch( + "yt_study.core.pipeline.get_video_title", + return_value="Test Video", + ), + patch("yt_study.core.pipeline.get_video_duration", return_value=100), + patch("yt_study.core.pipeline.get_video_chapters", return_value=[]), + patch( + "yt_study.core.pipeline.fetch_transcript", + new_callable=AsyncMock, + ) as mock_fetch, + ): + mock_transcript = MagicMock() + mock_transcript.to_text.return_value = "Transcript text" + mock_fetch.return_value = mock_transcript + + result = await pipeline.run( + ["vid123"], + on_event=lambda e: events.append(e), + ) + + assert result.success_count == 1 + assert result.failure_count == 0 + + # Check events + event_types = [e.event_type for e in events] + assert EventType.PIPELINE_START in event_types + assert EventType.VIDEO_SUCCESS in event_types + assert EventType.PIPELINE_COMPLETE in event_types + + @pytest.mark.asyncio + async def test_failure_emits_event(self, pipeline): + """Exception during processing emits VIDEO_FAILED.""" + events = [] + + with ( + patch( + "yt_study.core.pipeline.get_video_title", + side_effect=Exception("Network error"), + ), + ): + result = await pipeline.run( + ["bad_vid"], + on_event=lambda e: events.append(e), + ) + + assert result.success_count == 0 + assert result.failure_count == 1 + assert "bad_vid" in result.errors + + event_types = [e.event_type for e in events] + assert EventType.VIDEO_FAILED in event_types + + @pytest.mark.asyncio + async def test_missing_api_key(self, pipeline, monkeypatch): + """Missing API key fails all videos.""" + with patch( + "yt_study.core.config.config.get_api_key_name_for_model", + return_value="TEST_KEY", + ): + monkeypatch.delenv("TEST_KEY", raising=False) + result = await pipeline.run(["vid1", "vid2"]) + + assert result.success_count == 0 + assert result.failure_count == 2 + assert result.total_count == 2 + + @pytest.mark.asyncio + async def test_with_chapters(self, pipeline): + """Test processing a video with chapters.""" + with ( + patch( + 
"yt_study.core.pipeline.get_video_title", + return_value="Long Video", + ), + patch("yt_study.core.pipeline.get_video_duration", return_value=4000), + patch( + "yt_study.core.pipeline.get_video_chapters", + return_value=["chap1"], + ), + patch( + "yt_study.core.pipeline.fetch_transcript", + new_callable=AsyncMock, + ) as mock_fetch, + patch( + "yt_study.core.pipeline.split_transcript_by_chapters", + return_value={"Ch1": "text"}, + ), + ): + mock_transcript = MagicMock() + mock_fetch.return_value = mock_transcript + + result = await pipeline.run(["vid123"]) + + assert result.success_count == 1 + + # Verify folder creation + expected_folder = pipeline.output_dir / "Long Video" + assert expected_folder.exists() + assert (expected_folder / "01_Ch1.md").exists() + + @pytest.mark.asyncio + async def test_pipeline_result_structure(self, pipeline): + """Verify PipelineResult has correct structure.""" + with patch.object( + pipeline, "_process_single_video", new_callable=AsyncMock + ) as mock_proc: + mock_proc.return_value = True + + result = await pipeline.run(["v1", "v2"]) + + assert isinstance(result, PipelineResult) + assert result.total_count == 2 + assert result.video_ids == ["v1", "v2"] diff --git a/tests/test_pipeline/test_orchestrator.py b/tests/test_pipeline/test_orchestrator.py index 7b3237b..bbf42c2 100644 --- a/tests/test_pipeline/test_orchestrator.py +++ b/tests/test_pipeline/test_orchestrator.py @@ -1,139 +1,61 @@ -"""Tests for pipeline orchestrator.""" +"""Tests for pipeline presenter (Rich UI wrapper).""" from unittest.mock import AsyncMock, MagicMock, patch import pytest -from yt_study.core.orchestrator import PipelineOrchestrator, sanitize_filename +from yt_study.ui.presenter import RichPipelinePresenter -def test_sanitize_filename(): - """Test filename sanitization.""" - assert sanitize_filename("Hello World") == "Hello World" - assert sanitize_filename("foo/bar:baz") == "foobarbaz" - assert sanitize_filename(" spaces ") == "spaces" - assert 
len(sanitize_filename("a" * 200)) == 100 - - -class TestPipelineOrchestrator: - """Test orchestrator logic.""" +class TestRichPipelinePresenter: + """Test presenter logic.""" @pytest.fixture - def orchestrator(self, temp_output_dir, mock_llm_provider): + def presenter(self, temp_output_dir, mock_llm_provider): with patch( - "yt_study.core.orchestrator.get_provider", + "yt_study.core.pipeline.get_provider", return_value=mock_llm_provider, ): - orch = PipelineOrchestrator(model="mock-model", output_dir=temp_output_dir) - # Mock the generator inside - orch.generator = MagicMock() - orch.generator.generate_study_notes = AsyncMock(return_value="# Notes") - orch.generator.generate_single_chapter_notes = AsyncMock( - return_value="# Chapter Notes" + pres = RichPipelinePresenter( + model="mock-model", output_dir=temp_output_dir ) - orch.generator.provider = ( - mock_llm_provider # needed for direct calls in chapter loop + # Mock the inner pipeline's generator + pres.pipeline.generator = MagicMock() + pres.pipeline.generator.generate_study_notes = AsyncMock( + return_value="# Notes" ) - return orch + pres.pipeline.generator.generate_single_chapter_notes = AsyncMock( + return_value="# Chapter Notes" + ) + return pres - def test_validate_provider_missing_key(self, orchestrator, monkeypatch): + def test_validate_provider_missing_key(self, presenter, monkeypatch): """Test validation fails if key is missing.""" - # Mock config to return key name but env var is empty with patch( "yt_study.core.config.config.get_api_key_name_for_model", return_value="TEST_KEY", ): monkeypatch.delenv("TEST_KEY", raising=False) - assert orchestrator.validate_provider() is False + assert presenter.validate_provider() is False - def test_validate_provider_success(self, orchestrator, monkeypatch): + def test_validate_provider_success(self, presenter, monkeypatch): """Test validation succeeds if key exists.""" with patch( "yt_study.core.config.config.get_api_key_name_for_model", return_value="TEST_KEY", 
): monkeypatch.setenv("TEST_KEY", "123") - assert orchestrator.validate_provider() is True - - @pytest.mark.asyncio - async def test_process_video_single(self, orchestrator): - """Test processing a single video (no chapters).""" - # Mock dependencies - with ( - patch( - "yt_study.core.orchestrator.get_video_title", - return_value="Test Video", - ), - patch("yt_study.core.orchestrator.get_video_duration", return_value=100), - patch("yt_study.core.orchestrator.get_video_chapters", return_value=[]), - patch( - "yt_study.core.orchestrator.fetch_transcript", - new_callable=AsyncMock, - ) as mock_fetch, - ): - mock_transcript = MagicMock() - mock_transcript.to_text.return_value = "Transcript text" - mock_fetch.return_value = mock_transcript - - video_id = "vid123" - output_path = orchestrator.output_dir / "notes.md" - - success = await orchestrator.process_video(video_id, output_path) - - assert success is True - assert output_path.exists() - assert output_path.read_text(encoding="utf-8") == "# Notes" - orchestrator.generator.generate_study_notes.assert_called_once() + assert presenter.validate_provider() is True @pytest.mark.asyncio - async def test_process_video_with_chapters(self, orchestrator): - """Test processing a video with chapters.""" - # Mock dependencies - # Duration > 3600 (1h) + Chapters present - with ( - patch( - "yt_study.core.orchestrator.get_video_title", - return_value="Long Video", - ), - patch("yt_study.core.orchestrator.get_video_duration", return_value=4000), - patch( - "yt_study.core.orchestrator.get_video_chapters", - return_value=["chap1"], - ), - patch( - "yt_study.core.orchestrator.fetch_transcript", - new_callable=AsyncMock, - ) as mock_fetch, - patch( - "yt_study.core.orchestrator.split_transcript_by_chapters", - return_value={"Ch1": "text"}, - ), - ): - mock_transcript = MagicMock() - mock_fetch.return_value = mock_transcript - - video_id = "vid123" - output_path = ( - orchestrator.output_dir / "ignored.md" - ) # Folder structure used 
instead - - success = await orchestrator.process_video(video_id, output_path) - - assert success is True - # Verify folder creation - expected_folder = orchestrator.output_dir / "Long Video" - assert expected_folder.exists() - # Verify individual chapter file created (mock provider returns - # default text) - assert (expected_folder / "01_Ch1.md").exists() - - @pytest.mark.asyncio - async def test_run_video_flow(self, orchestrator): + async def test_run_video_flow(self, presenter): """Test run() method flow for a video URL.""" with ( - patch("yt_study.core.orchestrator.parse_youtube_url") as mock_parse, + patch( + "yt_study.ui.presenter.parse_youtube_url" + ) as mock_parse, patch.object( - orchestrator, "_process_with_dashboard", new_callable=AsyncMock + presenter, "_process_with_dashboard", new_callable=AsyncMock ) as mock_dash, ): mock_parsed = MagicMock() @@ -143,7 +65,7 @@ async def test_run_video_flow(self, orchestrator): mock_dash.return_value = 1 # 1 success - await orchestrator.run("http://url") + await presenter.run("http://url") mock_dash.assert_called_once() args = mock_dash.call_args diff --git a/tests/test_setup_wizard.py b/tests/test_setup_wizard.py index 8f8b0b1..f89f8fb 100644 --- a/tests/test_setup_wizard.py +++ b/tests/test_setup_wizard.py @@ -2,7 +2,7 @@ from unittest.mock import mock_open, patch -from yt_study.core.setup_wizard import ( +from yt_study.ui.wizard import ( get_api_key, get_available_models, load_config, @@ -58,11 +58,11 @@ def test_save_config(self): mock_path = Path("dummy_path") with ( patch( - "yt_study.core.setup_wizard.load_config", + "yt_study.ui.wizard.load_config", return_value={"OLD_KEY": "old_val"}, ), patch("pathlib.Path.open", mock_open()) as mock_file, - patch("yt_study.core.setup_wizard.get_config_path", return_value=mock_path), + patch("yt_study.ui.wizard.get_config_path", return_value=mock_path), ): new_config = {"NEW_KEY": "new_val", "DEFAULT_MODEL": "new_model"} save_config(new_config) @@ -152,7 +152,7 @@ def 
test_select_provider(self): } with ( - patch("yt_study.core.setup_wizard.PROVIDER_CONFIG", test_config), + patch("yt_study.ui.wizard.PROVIDER_CONFIG", test_config), patch("rich.prompt.Prompt.ask", return_value="2"), ): result = select_provider({"p1": [], "p2": []}) @@ -168,7 +168,7 @@ def test_select_model_pagination(self): inputs = ["n", "p", "1"] with ( - patch("yt_study.core.setup_wizard.PROVIDER_CONFIG", {"p1": {"name": "P1"}}), + patch("yt_study.ui.wizard.PROVIDER_CONFIG", {"p1": {"name": "P1"}}), patch("rich.prompt.Prompt.ask", side_effect=inputs), ): selected = select_model("p1", models) @@ -180,7 +180,7 @@ def test_select_model_gemini_prefix(self): with ( patch( - "yt_study.core.setup_wizard.PROVIDER_CONFIG", + "yt_study.ui.wizard.PROVIDER_CONFIG", {"gemini": {"name": "Google"}}, ), patch("rich.prompt.Prompt.ask", return_value="1"), @@ -223,19 +223,19 @@ def test_run_setup_wizard_full_flow(self): """Test full setup flow.""" # Mocks with ( - patch("yt_study.core.setup_wizard.load_config", return_value={}), + patch("yt_study.ui.wizard.load_config", return_value={}), patch( - "yt_study.core.setup_wizard.get_available_models", + "yt_study.ui.wizard.get_available_models", return_value={"gemini": ["gemini-pro"]}, ), - patch("yt_study.core.setup_wizard.select_provider", return_value="gemini"), + patch("yt_study.ui.wizard.select_provider", return_value="gemini"), patch( - "yt_study.core.setup_wizard.select_model", + "yt_study.ui.wizard.select_model", return_value="gemini/gemini-pro", ), - patch("yt_study.core.setup_wizard.get_api_key", return_value="new-key"), + patch("yt_study.ui.wizard.get_api_key", return_value="new-key"), patch("rich.prompt.Prompt.ask", side_effect=["/custom/out", "10"]), - patch("yt_study.core.setup_wizard.save_config") as mock_save, + patch("yt_study.ui.wizard.save_config") as mock_save, ): config = run_setup_wizard(force=True) @@ -250,7 +250,7 @@ def test_run_setup_wizard_skip_existing(self): """Test skipping setup if config exists.""" with 
( patch( - "yt_study.core.setup_wizard.load_config", + "yt_study.ui.wizard.load_config", return_value={"exists": "true"}, ), patch("rich.prompt.Confirm.ask", return_value=False),