diff --git a/.env.example b/.env.example index f82d014..3f855fb 100644 --- a/.env.example +++ b/.env.example @@ -2,7 +2,7 @@ # Copy this to .env and fill in your API keys # Required: At least one LLM API key (Gemini is the default provider) -GEMINI_API_KEY=your_gemini_key_here +GEMINI_API_KEY="your_gemini_api_key_here" # Optional: Other LLM Providers # OPENAI_API_KEY=your_openai_key_here diff --git a/.gitignore b/.gitignore index 1dfecf9..217717a 100644 --- a/.gitignore +++ b/.gitignore @@ -16,5 +16,8 @@ nul drafts/ .omc .claude + +# Generated content +src/yt_study/output/ coverage.xml htmlcov/ diff --git a/src/yt_study/cli.py b/src/yt_study/cli.py index 622e8fc..7ab20bc 100644 --- a/src/yt_study/cli.py +++ b/src/yt_study/cli.py @@ -21,80 +21,63 @@ log_dir = Path.home() / ".yt-study" / "logs" try: log_dir.mkdir(parents=True, exist_ok=True) -except Exception: - # Fallback if home is not writable - log_dir = Path.cwd() / "logs" - log_dir.mkdir(parents=True, exist_ok=True) - -# Use timestamped log file for session isolation -timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") -log_file = log_dir / f"yt-study-{timestamp}.log" - -root_logger = logging.getLogger() -root_logger.setLevel(logging.INFO) - -# Console Handler: Warning+, Clean output -console_handler = RichHandler(rich_tracebacks=False, show_time=False, show_path=False) -console_handler.setLevel(logging.WARNING) -root_logger.addHandler(console_handler) - -# File Handler: Debug+, Detailed format -try: - file_handler = logging.FileHandler(log_file, encoding="utf-8") - file_handler.setLevel(logging.DEBUG) - file_handler.setFormatter( - logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + log_file = log_dir / f"yt_study_{datetime.now():%Y%m%d}.log" + + # Configure root logger + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(log_file, encoding="utf-8"), + RichHandler( + rich_tracebacks=True, + show_path=False, + markup=True, + level=logging.WARNING, + ), + ], ) - root_logger.addHandler(file_handler) -except Exception: - pass +except OSError: + # Fall back to basic logging if directory creation fails + logging.basicConfig(level=logging.INFO) + app = typer.Typer( name="yt-study", - help=( - "🎓 Convert YouTube videos and playlists into comprehensive " - "study materials using AI." - ), - add_completion=True, + help="🎓 YouTube Study Material Generator • Transform videos into study notes.", + add_completion=False, rich_markup_mode="rich", ) console = Console() +# Module-level logger +logger = logging.getLogger("yt-study") + def check_config_exists() -> bool: - """Check if user configuration exists.""" + """Check if configuration file exists.""" config_path = Path.home() / ".yt-study" / "config.env" return config_path.exists() def ensure_setup() -> None: - """ - Ensure setup wizard has been run. - Triggers setup if config is missing. - """ + """Check if setup has been run, prompt if not.""" if not check_config_exists(): console.print( - "\n[yellow]⚠ No configuration found. Running setup wizard...[/yellow]\n" + "[yellow]⚠ No configuration found. Running setup wizard...[/yellow]\n" ) - try: - from .setup_wizard import run_setup_wizard + from .ui.wizard import run_setup_wizard - run_setup_wizard(force=False) - except ImportError as e: - console.print("[red]Critical: Could not import setup wizard.[/red]") - raise typer.Exit(code=1) from e + run_setup_wizard() -@app.command() -def process( +@app.command(name="process") +def process_command( url: Annotated[ str, typer.Argument( - help=( - "YouTube video or playlist URL, or path to a text file containing URLs." - ), - show_default=False, + help="YouTube video/playlist URL or path to batch file.", ), ], model: Annotated[ @@ -102,10 +85,7 @@ def process( typer.Option( "--model", "-m", - help=( - "LLM model (overrides config). Example: [green]gpt-4o[/green] " - "or [green]gemini/gemini-2.0-flash[/green]" - ), + help="LLM model to use. Overrides config.", ), ] = None, output: Annotated[ @@ -113,22 +93,15 @@ def process( typer.Option( "--output", "-o", - help="Output directory (overrides config).", - exists=False, - file_okay=False, - dir_okay=True, - resolve_path=True, + help="Output directory.", ), ] = None, - language: Annotated[ - list[str] | None, + lang: Annotated[ + str | None, typer.Option( - "--language", + "--lang", "-l", - help=( - "Preferred transcript languages " - "(e.g., [green]en[/green], [green]hi[/green])." - ), + help="Preferred transcript language codes (comma-separated).", ), ] = None, temperature: Annotated[ @@ -136,191 +109,151 @@ def process( typer.Option( "--temperature", "-t", - help=( - "LLM response temperature (overrides config). " - "Range: 0.0 to 1.0 (default = 0.7)" - ), - min=0.0, - max=1.0, + help="LLM temperature (0.0-1.0). Controls output randomness.", ), ] = None, max_tokens: Annotated[ int | None, typer.Option( "--max-tokens", - "-k", - help=( - "Maximum tokens for LLM responses (overrides config). " - "Adjust based on model limits. (None for model default)" - ), - min=1, + help="Maximum tokens for LLM response.", ), ] = None, ) -> None: - """ - Generate comprehensive study notes from YouTube videos or playlists. - - Supports: - \b - 1. Single Video URL - 2. Playlist URL - 3. Batch file (text file with one URL per line) - - \b - Examples: - [cyan]yt-study process "https://youtube.com/watch?v=VIDEO_ID"[/cyan] - [cyan]yt-study process "URL" -m gpt-4o[/cyan] - [cyan]yt-study process batch_urls.txt -o ./course-notes[/cyan] - """ - # Ensure configuration exists - ensure_setup() - - try: - # Lazy import for faster CLI startup - from .config import config - from .pipeline.orchestrator import PipelineOrchestrator - - # Use config values as defaults, allow CLI overrides - selected_model = model or config.default_model - selected_output = output or config.default_output_dir - selected_languages = language or config.default_languages - selected_temperature = ( - temperature if temperature is not None else config.temperature - ) - selected_max_tokens = ( - max_tokens if max_tokens is not None else config.max_tokens - ) + """Process a YouTube video or playlist into study materials.""" + from rich.panel import Panel - # Create orchestrator - orchestrator = PipelineOrchestrator( - model=selected_model, - output_dir=selected_output, - languages=selected_languages, - temperature=selected_temperature, - max_tokens=selected_max_tokens, - ) - - async def run_processing() -> None: - """Determine if input is URL or file and run pipeline.""" - input_path = Path(url) - - # Check if input is an existing file (Batch Mode) - if input_path.exists() and input_path.is_file(): - # Removed redundant panel print here since dashboard handles UI - try: - # Robust encoding handling and line splitting - content = input_path.read_text(encoding="utf-8") - urls = [ - line.strip() - for line in content.splitlines() - if line.strip() and not line.strip().startswith("#") - ] - except Exception as e: - console.print( - f"[bold red]❌ Error reading batch file:[/bold red] {e}" - ) - return - - if not urls: - console.print("[yellow]⚠ Batch file is empty.[/yellow]") - return - - # Removed: console.print(f"[dim]Found {len(urls)} URLs[/dim]\n") - - for i, batch_url in enumerate(urls, 1): - # Keep this rule as it separates batch items distinctly - console.rule(f"[bold cyan]Batch Item {i}/{len(urls)}[/bold cyan]") - # Removed redundant URL print as dashboard shows title/status - try: - await orchestrator.run(batch_url) - except Exception as e: - console.print(f"[bold red]❌ Batch item failed:[/bold red] {e}") - else: - # Single URL Mode (Orchestrator handles Video vs Playlist detection) - await orchestrator.run(url) - - # Run pipeline - asyncio.run(run_processing()) - - except KeyboardInterrupt: - console.print("\n[yellow]⚠ Process interrupted by user[/yellow]") - raise typer.Exit(code=1) from None - except Exception as e: - # Import Panel locally - from rich.panel import Panel - - console.print( - Panel(f"[bold red]Fatal Error[/bold red]\n{str(e)}", border_style="red") - ) - logging.exception("Fatal error in CLI process") - raise typer.Exit(code=1) from e + from .core.config import config + ensure_setup() -@app.callback(invoke_without_command=True) -def main(ctx: typer.Context) -> None: - """ - [bold cyan]yt-study[/bold cyan]: AI-Powered Video Study Notes Generator. - - Convert YouTube content into structured Markdown notes. - """ - # Only show help if no subcommand is being invoked - if ctx.invoked_subcommand is None: - console.print(ctx.get_help()) - + # Resolve parameters from config defaults + selected_model = model or config.default_model + output_dir = output or config.default_output_dir + languages = lang.split(",") if lang else config.default_languages -@app.command() -def setup( - force: Annotated[ - bool, - typer.Option( - "--force", "-f", help="Force reconfiguration even if config exists." - ), - ] = False, -) -> None: - """ - Configure API keys and preferences interactively. - - Runs a wizard to generate the [bold]~/.yt-study/config.env[/bold] file. - """ + # Check if input is a batch file + input_path = Path(url) + if input_path.exists() and input_path.is_file(): + # Batch file mode + try: + urls = [ + line.strip() + for line in input_path.read_text(encoding="utf-8").strip().split("\n") + if line.strip() and not line.strip().startswith("#") + ] + if not urls: + console.print("[yellow]Batch file is empty.[/yellow]") + return + except OSError as e: + console.print(f"[red]Error reading batch file: {e}[/red]") + return + + console.print(f"[cyan]Processing {len(urls)} URLs from batch file...[/cyan]\n") + + from .ui.presenter import RichPipelinePresenter + + for batch_url in urls: + try: + presenter = RichPipelinePresenter( + model=selected_model, + output_dir=output_dir, + languages=languages, + temperature=temperature, + max_tokens=max_tokens, + ) + asyncio.run(presenter.run(batch_url)) + except KeyboardInterrupt: + console.print( + Panel( + "⚠ [yellow]Process interrupted by user[/yellow]", + border_style="yellow", + ) + ) + raise typer.Exit(code=1) from None + except Exception as e: + console.print(f"[red]Error processing {batch_url}: {e}[/red]") + logger.exception(f"Failed to process URL: {batch_url}") + else: + # Single URL mode + try: + from .ui.presenter import RichPipelinePresenter + + presenter = RichPipelinePresenter( + model=selected_model, + output_dir=output_dir, + languages=languages, + temperature=temperature, + max_tokens=max_tokens, + ) + asyncio.run(presenter.run(url)) + + except KeyboardInterrupt: + console.print( + Panel( + "⚠ [yellow]Process interrupted by user[/yellow]", + border_style="yellow", + ) + ) + raise typer.Exit(code=1) from None + except Exception as e: + console.print( + Panel( + f"[red bold]Fatal Error[/red bold]\n\n" + f"[white]{e}[/white]\n\n" + f"[dim]Check logs at {log_dir} for details.[/dim]", + title="💥 Error", + border_style="red", + ) + ) + logger.exception("Fatal CLI error") + raise typer.Exit(code=1) from None + + +@app.command(name="setup") +def setup_command() -> None: + """Run the interactive setup wizard.""" + from .ui.wizard import run_setup_wizard + + run_setup_wizard(force=True) + + +@app.command(name="version") +def version_command() -> None: + """Show yt-study version.""" try: - from .setup_wizard import run_setup_wizard + from . import __version__ - run_setup_wizard(force=force) - except ImportError as e: - console.print("[red]Setup wizard module missing.[/red]") - raise typer.Exit(code=1) from e + console.print(f"[cyan]yt-study[/cyan] version [bold]{__version__}[/bold]") + except ImportError: + console.print("[cyan]yt-study[/cyan] version [bold]0.1.0-dev[/bold]") -@app.command() -def config_path() -> None: +@app.command(name="config-path") +def config_path_command() -> None: """Show the path to the configuration file.""" - config_file = Path.home() / ".yt-study" / "config.env" + config_path = Path.home() / ".yt-study" / "config.env" - if config_file.exists(): - console.print(f"\n[cyan]Configuration file:[/cyan] {config_file}") - console.print("\n[dim]To edit: Open the file above in a text editor[/dim]") + if config_path.exists(): console.print( - "[dim]To reconfigure: Run[/dim] [cyan]yt-study setup --force[/cyan]\n" + f"[green]Configuration file:[/green] [cyan]{config_path}[/cyan]" ) else: - console.print("\n[yellow]No configuration found.[/yellow]") console.print( - "[dim]Run[/dim] [cyan]yt-study setup[/cyan] [dim]to create one.[/dim]\n" + "[yellow]No configuration found.[/yellow]\n" + "Run [cyan]yt-study setup[/cyan] to configure." ) -@app.command() -def version() -> None: - """Show version information.""" - try: - from . import __version__ - - ver = __version__ - except ImportError: - ver = "dev" - - console.print(f"[cyan]yt-study[/cyan] version [green]{ver}[/green]") +@app.callback(invoke_without_command=True) +def main( + ctx: typer.Context, +) -> None: + """🎓 YouTube Study Material Generator • Transform videos into study notes.""" + if ctx.invoked_subcommand is None: + console.print(ctx.get_help()) -if __name__ == "__main__": +def cli() -> None: + """Entry point for the CLI.""" app() diff --git a/src/yt_study/core/__init__.py b/src/yt_study/core/__init__.py new file mode 100644 index 0000000..10edbf4 --- /dev/null +++ b/src/yt_study/core/__init__.py @@ -0,0 +1,37 @@ +"""Core pipeline module - zero UI dependencies. + +This module provides the core pipeline functionality that can be used +by any frontend (CLI, web API, GUI, etc.) without UI dependencies. + +Usage: + >>> from yt_study.core import CorePipeline, EventType + >>> + >>> pipeline = CorePipeline(model="gemini-1.5-flash") + >>> + >>> def on_progress(event): + ... if event.event_type == EventType.VIDEO_SUCCESS: + ... print(f"Done: {event.title}") + >>> + >>> result = await pipeline.run(["VIDEO_ID"], on_event=on_progress) +""" + +from .events import ( + EventType, + PipelineEvent, + PipelineResult, +) +from .pipeline import ( + CorePipeline, + run_pipeline, + sanitize_filename, +) + + +__all__ = [ + "CorePipeline", + "EventType", + "PipelineEvent", + "PipelineResult", + "run_pipeline", + "sanitize_filename", +] diff --git a/src/yt_study/config.py b/src/yt_study/core/config.py similarity index 100% rename from src/yt_study/config.py rename to src/yt_study/core/config.py diff --git a/src/yt_study/core/events.py b/src/yt_study/core/events.py new file mode 100644 index 0000000..a610a85 --- /dev/null +++ b/src/yt_study/core/events.py @@ -0,0 +1,49 @@ +"""Event types and data classes for pipeline communication. + +These events allow the core pipeline to communicate progress +to any UI layer without depending on specific UI libraries. +""" + +from dataclasses import dataclass +from enum import Enum +from pathlib import Path + + +class EventType(Enum): + """Event types emitted by the pipeline.""" + + METADATA_START = "metadata_start" + METADATA_FETCHED = "metadata_fetched" + TRANSCRIPT_FETCHING = "transcript_fetching" + TRANSCRIPT_FETCHED = "transcript_fetched" + GENERATION_START = "generation_start" + CHAPTER_GENERATING = "chapter_generating" + GENERATION_COMPLETE = "generation_complete" + VIDEO_SUCCESS = "video_success" + VIDEO_FAILED = "video_failed" + PIPELINE_START = "pipeline_start" + PIPELINE_COMPLETE = "pipeline_complete" + + +@dataclass +class PipelineEvent: + """Event emitted during pipeline execution.""" + + event_type: EventType + video_id: str + title: str | None = None + chapter_number: int | None = None + total_chapters: int | None = None + error: str | None = None + output_path: Path | None = None + + +@dataclass +class PipelineResult: + """Result of pipeline execution.""" + + success_count: int + failure_count: int + total_count: int + video_ids: list[str] + errors: dict[str, str] # video_id -> error message diff --git a/src/yt_study/llm/__init__.py b/src/yt_study/core/llm/__init__.py similarity index 100% rename from src/yt_study/llm/__init__.py rename to src/yt_study/core/llm/__init__.py diff --git a/src/yt_study/llm/generator.py b/src/yt_study/core/llm/generator.py similarity index 71% rename from src/yt_study/llm/generator.py rename to src/yt_study/core/llm/generator.py index 0dfee2c..829e870 100644 --- a/src/yt_study/llm/generator.py +++ b/src/yt_study/core/llm/generator.py @@ -1,10 +1,9 @@ """Study material generator with chunking and combining logic.""" import logging +from collections.abc import Callable from litellm import token_counter -from rich.console import Console -from rich.progress import Progress, TaskID from ..config import config from ..prompts.chapter_notes import ( @@ -23,7 +22,6 @@ # Re-use system prompt for now CHAPTER_SYSTEM_PROMPT = SYSTEM_PROMPT -console = Console() logger = logging.getLogger(__name__) @@ -54,8 +52,6 @@ def __init__( def _count_tokens(self, text: str) -> int: """Count tokens in text using model-specific tokenizer.""" - # Note: token_counter might do network calls for some models or use - # local libraries (tiktoken). For efficiency, we assume it's fast. try: count = token_counter(model=self.provider.model, text=text) return int(count) if count is not None else len(text) // 4 @@ -110,8 +106,6 @@ def _chunk_transcript(self, transcript: str) -> list[str]: continue # Re-add delimiter for estimation (approximate) - # We assume '. ' was the delimiter for simplicity, logic holds - # for others mostly as we care about token count term = sentence + ". " term_tokens = self._count_tokens(term) @@ -124,7 +118,6 @@ def _chunk_transcript(self, transcript: str) -> list[str]: current_tokens = 0 # 2. Hard split the massive segment - # Estimate char limit based on token size (conservative 3 chars/token) char_limit = config.chunk_size * 3 for i in range(0, len(sentence), char_limit): sub_part = sentence[i : i + char_limit] @@ -141,7 +134,6 @@ def _chunk_transcript(self, transcript: str) -> list[str]: overlap_chunk: list[str] = [] overlap_tokens = 0 - # Take sentences from the end of current_chunk until overlap limit for prev_sent in reversed(current_chunk): prev_tokens = self._count_tokens(prev_sent) if overlap_tokens + prev_tokens <= config.chunk_overlap: @@ -153,7 +145,6 @@ def _chunk_transcript(self, transcript: str) -> list[str]: current_chunk = overlap_chunk + [sentence] current_tokens = self._count_tokens(" ".join(current_chunk)) else: - # Should be unreachable due to check above, but safe fallback current_chunk.append(sentence) current_tokens += term_tokens else: @@ -167,31 +158,11 @@ def _chunk_transcript(self, transcript: str) -> list[str]: logger.info(f"Created {len(chunks)} chunks") return chunks - def _update_status( - self, - progress: Progress | None, - task_id: TaskID | None, - video_title: str, - message: str, - ) -> None: - """Safe helper to update progress bar or log message.""" - if progress and task_id is not None: - short_title = ( - (video_title[:20] + "...") if len(video_title) > 20 else video_title - ) - # We assume the layout uses 'description' for the status text - progress.update( - task_id, description=f"[yellow]{short_title}[/yellow]: {message}" - ) - else: - logger.info(f"{video_title}: {message}") - async def generate_study_notes( self, transcript: str, video_title: str = "Video", - progress: Progress | None = None, - task_id: TaskID | None = None, + on_status: Callable[[str], None] | None = None, ) -> str: """ Generate study notes from transcript. @@ -199,17 +170,22 @@ async def generate_study_notes( Args: transcript: Full video transcript text. video_title: Video title for progress display. - progress: Optional existing progress bar instance. - task_id: Optional task ID for updating progress. + on_status: Optional callback for status updates. + Signature: (status_message: str) -> None Returns: Complete study notes in Markdown format. """ chunks = self._chunk_transcript(transcript) + def _emit(msg: str) -> None: + logger.info(f"{video_title}: {msg}") + if on_status: + on_status(msg) + # Single chunk - generate directly if len(chunks) == 1: - self._update_status(progress, task_id, video_title, "Generating notes...") + _emit("Generating notes...") notes = await self.provider.generate( system_prompt=SYSTEM_PROMPT, @@ -218,23 +194,15 @@ async def generate_study_notes( max_tokens=self.max_tokens, ) - if not progress: - logger.info(f"Generated notes for {video_title}") return notes # Multiple chunks - generate for each, then combine - self._update_status( - progress, - task_id, - video_title, - f"Generating notes for {len(chunks)} chunks...", - ) + _emit(f"Generating notes for {len(chunks)} chunks...") chunk_notes = [] for i, chunk in enumerate(chunks, 1): - msg = f"Chunk {i}/{len(chunks)} (Generating)" - self._update_status(progress, task_id, video_title, msg) + _emit(f"Chunk {i}/{len(chunks)} (Generating)") note = await self.provider.generate( system_prompt=SYSTEM_PROMPT, @@ -244,12 +212,7 @@ async def generate_study_notes( ) chunk_notes.append(note) - self._update_status( - progress, - task_id, - video_title, - f"Combining {len(chunk_notes)} chunk notes...", - ) + _emit(f"Combining {len(chunk_notes)} chunk notes...") final_notes = await self.provider.generate( system_prompt=SYSTEM_PROMPT, @@ -258,9 +221,6 @@ async def generate_study_notes( max_tokens=self.max_tokens, ) - if not progress: - logger.info(f"Completed notes for {video_title}") - return final_notes async def generate_single_chapter_notes( @@ -290,8 +250,7 @@ async def generate_chapter_based_notes( self, chapter_transcripts: dict[str, str], video_title: str = "Video", - progress: Progress | None = None, - task_id: TaskID | None = None, + on_status: Callable[[str], None] | None = None, ) -> str: """ Generate study notes using chapter-based approach. @@ -299,22 +258,18 @@ async def generate_chapter_based_notes( Args: chapter_transcripts: Dictionary mapping chapter titles to transcript text. video_title: Video title for display. - progress: Optional existing progress bar instance. - task_id: Optional task ID for updating progress. + on_status: Optional callback for status updates. Returns: Complete study notes organized by chapters. """ - # Imports are already at top-level or can be moved up, but let's - # fix the specific issue. Previously we did lazy import inside - # function which caused issues - - self._update_status( - progress, - task_id, - video_title, - f"Generating notes for {len(chapter_transcripts)} chapters...", - ) + + def _emit(msg: str) -> None: + logger.info(f"{video_title}: {msg}") + if on_status: + on_status(msg) + + _emit(f"Generating notes for {len(chapter_transcripts)} chapters...") chapter_notes = {} total_chapters = len(chapter_transcripts) @@ -322,13 +277,7 @@ async def generate_chapter_based_notes( for i, (chapter_title, chapter_text) in enumerate( chapter_transcripts.items(), 1 ): - msg = f"Chapter {i}/{total_chapters}: {chapter_title[:20]}..." - self._update_status(progress, task_id, video_title, msg) - - # If a chapter is huge, we might need recursive chunking here too. - # For now, we assume chapters are reasonably sized or the model - # can handle ~100k context. Future improvement: Check token - # count of chapter_text and recurse if needed. + _emit(f"Chapter {i}/{total_chapters}: {chapter_title[:20]}...") notes = await self.provider.generate( system_prompt=CHAPTER_SYSTEM_PROMPT, @@ -338,9 +287,7 @@ async def generate_chapter_based_notes( ) chapter_notes[chapter_title] = notes - self._update_status( - progress, task_id, video_title, "Combining chapter notes..." - ) + _emit("Combining chapter notes...") final_notes = await self.provider.generate( system_prompt=CHAPTER_SYSTEM_PROMPT, @@ -349,7 +296,4 @@ async def generate_chapter_based_notes( max_tokens=self.max_tokens, ) - if not progress: - logger.info(f"Completed chapter-based notes for {video_title}") - return final_notes diff --git a/src/yt_study/llm/providers.py b/src/yt_study/core/llm/providers.py similarity index 98% rename from src/yt_study/llm/providers.py rename to src/yt_study/core/llm/providers.py index 625e45d..e17975a 100644 --- a/src/yt_study/llm/providers.py +++ b/src/yt_study/core/llm/providers.py @@ -5,12 +5,10 @@ from typing import Any from litellm import acompletion -from rich.console import Console from ..config import config -console = Console() logger = logging.getLogger(__name__) diff --git a/src/yt_study/core/pipeline.py b/src/yt_study/core/pipeline.py new file mode 100644 index 0000000..68ef63c --- /dev/null +++ b/src/yt_study/core/pipeline.py @@ -0,0 +1,397 @@ +""" +Core pipeline orchestrator with concurrent processing. + +This module provides the single entry point for the pipeline. +All UI concerns are handled externally via event callbacks. +No Rich, no Console, no Dashboard imports here. +""" + +import asyncio +import logging +import re +from collections.abc import Callable +from pathlib import Path + +from .config import config +from .events import EventType, PipelineEvent, PipelineResult +from .llm.generator import StudyMaterialGenerator +from .llm.providers import get_provider +from .youtube.metadata import ( + get_video_chapters, + get_video_duration, + get_video_title, +) +from .youtube.transcript import ( + YouTubeIPBlockError, + fetch_transcript, + split_transcript_by_chapters, +) + + +logger = logging.getLogger(__name__) + + +def sanitize_filename(name: str) -> str: + """ + Sanitize a string to be used as a filename. + + Args: + name: Raw filename string. + + Returns: + Sanitized string safe for file systems. + """ + name = re.sub(r'[<>:"/\\|?*]', "", name) + name = re.sub(r"\s+", " ", name) + name = name.strip()[:100] + + # Prevent directory traversal + if name in {".", ".."}: + return "untitled" + + return name if name else "untitled" + + +class CorePipeline: + """ + Core pipeline orchestrator. + + Pure business logic - no UI concerns. + Single public entry point: `run()` + + Communicates progress via event callbacks. + """ + + def __init__( + self, + model: str = "gemini/gemini-2.0-flash", + output_dir: Path | None = None, + languages: list[str] | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + ): + """ + Initialize the core pipeline. + + Args: + model: LLM model string. + output_dir: Output directory path. + languages: Preferred transcript languages. + temperature: LLM temperature. + max_tokens: Max tokens for generation. + """ + self.model = model + self.output_dir = output_dir or config.default_output_dir + self.languages = languages or config.default_languages + self.temperature = ( + temperature if temperature is not None else config.temperature + ) + self.max_tokens = max_tokens if max_tokens is not None else config.max_tokens + + self.provider = get_provider(model) + self.generator = StudyMaterialGenerator( + self.provider, + temperature=self.temperature, + max_tokens=self.max_tokens, + ) + self.semaphore = asyncio.Semaphore(config.max_concurrent_videos) + self.errors: dict[str, str] = {} + + def _check_api_key(self) -> bool: + """ + Check if API key is configured. + + Returns: + True if valid, False otherwise. + Errors are logged but not printed (UI's responsibility). + """ + key_name = config.get_api_key_name_for_model(self.model) + + if key_name: + import os + + if not os.environ.get(key_name): + logger.error(f"Missing API Key for {self.model}. Expected: {key_name}") + return False + return True + + async def _process_single_video( + self, + video_id: str, + output_path: Path, + on_event: Callable[[PipelineEvent], None] | None = None, + ) -> bool: + """ + Process a single video: fetch transcript and generate study notes. + + This is an INTERNAL method (async worker). + It emits events that the CLI/UI can listen to. + + Args: + video_id: YouTube Video ID. + output_path: Destination path for the MD file. + on_event: Callback for progress events. + + Returns: + True on success, False on failure. + """ + async with self.semaphore: + try: + # --- Metadata Phase --- + emit = self._emit_event(on_event) + emit(EventType.METADATA_START, video_id) + + # Fetch metadata concurrently + title = await asyncio.to_thread(get_video_title, video_id) + duration, chapters = await asyncio.gather( + asyncio.to_thread(get_video_duration, video_id), + asyncio.to_thread(get_video_chapters, video_id), + ) + + emit( + EventType.METADATA_FETCHED, + video_id, + title=title, + chapter_number=len(chapters) if chapters else 0, + ) + + # --- Transcript Phase --- + emit(EventType.TRANSCRIPT_FETCHING, video_id, title=title) + + transcript_obj = await fetch_transcript(video_id, self.languages) + + emit(EventType.TRANSCRIPT_FETCHED, video_id, title=title) + + # --- Generation Strategy --- + use_chapters = duration > 3600 and len(chapters) > 0 + + if use_chapters: + # Chapter-based generation + chapter_transcripts = split_transcript_by_chapters( + transcript_obj, chapters + ) + + safe_title = sanitize_filename(title) + output_folder = self.output_dir / safe_title + await asyncio.to_thread( + lambda: output_folder.mkdir(parents=True, exist_ok=True) + ) + + total_chapters = len(chapter_transcripts) + + for i, (chap_title, chap_text) in enumerate( + chapter_transcripts.items(), 1 + ): + emit( + EventType.CHAPTER_GENERATING, + video_id, + title=title, + chapter_number=i, + total_chapters=total_chapters, + ) + + notes = await self.generator.generate_single_chapter_notes( + chapter_title=chap_title, + chapter_text=chap_text, + ) + + safe_chapter = sanitize_filename(chap_title) + chapter_file = output_folder / f"{i:02d}_{safe_chapter}.md" + await asyncio.to_thread( + chapter_file.write_text, notes, encoding="utf-8" + ) + + emit( + EventType.GENERATION_COMPLETE, + video_id, + title=title, + output_path=output_folder, + ) + emit(EventType.VIDEO_SUCCESS, video_id, title=title) + return True + + else: + # Single file generation + emit(EventType.GENERATION_START, video_id, title=title) + + transcript_text = transcript_obj.to_text() + notes = await self.generator.generate_study_notes( + transcript_text, + video_title=title, + ) + + await asyncio.to_thread( + lambda: output_path.parent.mkdir(parents=True, exist_ok=True) + ) + await asyncio.to_thread( + output_path.write_text, notes, encoding="utf-8" + ) + + emit( + EventType.GENERATION_COMPLETE, + video_id, + title=title, + output_path=output_path, + ) + emit(EventType.VIDEO_SUCCESS, video_id, title=title) + return True + + except YouTubeIPBlockError as e: + error_msg = "YouTube IP blocked - use VPN or wait 1 hour" + logger.error(f"IP Block for {video_id}: {e}") + self.errors[video_id] = error_msg + emit(EventType.VIDEO_FAILED, video_id, error=error_msg) + return False + + except Exception as e: + error_msg = f"{type(e).__name__}: {str(e)}" + logger.error(f"Failed to process {video_id}: {e}", exc_info=True) + self.errors[video_id] = error_msg + emit(EventType.VIDEO_FAILED, video_id, error=error_msg) + return False + + def _emit_event( + self, + on_event: Callable[[PipelineEvent], None] | None, + ) -> Callable[..., None]: + """ + Create a helper function to emit events. + + This allows cleaner event emission throughout the code. + """ + + def emit( + event_type: EventType, + video_id: str, + title: str | None = None, + chapter_number: int | None = None, + total_chapters: int | None = None, + error: str | None = None, + output_path: Path | None = None, + ) -> None: + if on_event: + event = PipelineEvent( + event_type=event_type, + video_id=video_id, + title=title, + chapter_number=chapter_number, + total_chapters=total_chapters, + error=error, + output_path=output_path, + ) + try: + on_event(event) + except Exception as e: + logger.warning(f"Event handler error: {e}") + + return emit + + async def run( + self, + video_ids: list[str], + on_event: Callable[[PipelineEvent], None] | None = None, + ) -> PipelineResult: + """ + Process a list of video IDs concurrently. + + Args: + video_ids: List of YouTube video IDs to process. + on_event: Optional callback for progress events. + Signature: (event: PipelineEvent) -> None + + Returns: + PipelineResult with success count, failures, and detailed errors. + + Example (CLI usage): + >>> pipeline = CorePipeline(model="gemini-1.5-flash") + >>> + >>> def on_progress(event): + ... if event.event_type == EventType.VIDEO_SUCCESS: + ... print(f"✓ {event.title}") + ... elif event.event_type == EventType.VIDEO_FAILED: + ... print(f"✗ {event.title}: {event.error}") + >>> + >>> result = await pipeline.run( + ... ["VIDEO_ID_1", "VIDEO_ID_2"], + ... on_event=on_progress + ... ) + >>> print(f"Completed: {result.success_count}/{result.total_count}") + """ + # --- Validation --- + if not self._check_api_key(): + return PipelineResult( + success_count=0, + failure_count=len(video_ids), + total_count=len(video_ids), + video_ids=video_ids, + errors={vid: "Missing API key" for vid in video_ids}, + ) + + if not video_ids: + return PipelineResult( + success_count=0, + failure_count=0, + total_count=0, + video_ids=[], + errors={}, + ) + + emit = self._emit_event(on_event) + emit(EventType.PIPELINE_START, video_ids[0]) + + self.errors.clear() + success_count = 0 + + # --- Process all videos concurrently --- + tasks = [] + for video_id in video_ids: + safe_title = sanitize_filename(video_id) + output_path = self.output_dir / f"{safe_title}.md" + + task = self._process_single_video( + video_id, + output_path, + on_event=on_event, + ) + tasks.append(task) + + # Gather results + results = await asyncio.gather(*tasks, return_exceptions=False) + success_count = sum(1 for r in results if r is True) + + # --- Return Structured Result --- + failure_count = len(video_ids) - success_count + + emit(EventType.PIPELINE_COMPLETE, video_ids[0]) + + return PipelineResult( + success_count=success_count, + failure_count=failure_count, + total_count=len(video_ids), + video_ids=video_ids, + errors=self.errors, + ) + + +async def run_pipeline( + video_ids: list[str], + output_dir: Path | None = None, + model: str = "gemini/gemini-2.0-flash", + on_event: Callable[[PipelineEvent], None] | None = None, +) -> PipelineResult: + """ + Convenience function for simple usage. + + Alternative to CorePipeline class. + + Args: + video_ids: List of YouTube video IDs to process. + output_dir: Optional output directory. + model: LLM model string. + on_event: Optional callback for progress events. + + Returns: + PipelineResult with success/failure counts. + """ + pipeline = CorePipeline(model=model, output_dir=output_dir) + return await pipeline.run(video_ids, on_event=on_event) diff --git a/src/yt_study/prompts/__init__.py b/src/yt_study/core/prompts/__init__.py similarity index 100% rename from src/yt_study/prompts/__init__.py rename to src/yt_study/core/prompts/__init__.py diff --git a/src/yt_study/prompts/chapter_notes.py b/src/yt_study/core/prompts/chapter_notes.py similarity index 100% rename from src/yt_study/prompts/chapter_notes.py rename to src/yt_study/core/prompts/chapter_notes.py diff --git a/src/yt_study/prompts/study_notes.py b/src/yt_study/core/prompts/study_notes.py similarity index 100% rename from src/yt_study/prompts/study_notes.py rename to src/yt_study/core/prompts/study_notes.py diff --git a/src/yt_study/youtube/__init__.py b/src/yt_study/core/youtube/__init__.py similarity index 100% rename from src/yt_study/youtube/__init__.py rename to src/yt_study/core/youtube/__init__.py diff --git a/src/yt_study/youtube/metadata.py b/src/yt_study/core/youtube/metadata.py similarity index 98% rename from src/yt_study/youtube/metadata.py rename to src/yt_study/core/youtube/metadata.py index 14f2bca..c0cd834 100644 --- a/src/yt_study/youtube/metadata.py +++ b/src/yt_study/core/youtube/metadata.py @@ -5,10 +5,8 @@ from typing import Any from pytubefix import Playlist, YouTube -from rich.console import Console -console = Console() logger = logging.getLogger(__name__) diff --git a/src/yt_study/youtube/parser.py b/src/yt_study/core/youtube/parser.py similarity index 100% rename from src/yt_study/youtube/parser.py rename to src/yt_study/core/youtube/parser.py diff --git a/src/yt_study/youtube/playlist.py b/src/yt_study/core/youtube/playlist.py similarity index 98% rename from src/yt_study/youtube/playlist.py rename to src/yt_study/core/youtube/playlist.py index 706867a..6f4b354 100644 --- a/src/yt_study/youtube/playlist.py +++ b/src/yt_study/core/youtube/playlist.py @@ -4,10 +4,8 @@ import logging from pytubefix import Playlist -from rich.console import Console -console = Console() logger = logging.getLogger(__name__) diff --git a/src/yt_study/youtube/transcript.py b/src/yt_study/core/youtube/transcript.py similarity index 99% rename from src/yt_study/youtube/transcript.py rename to src/yt_study/core/youtube/transcript.py index 1f8a2ac..b825d69 100644 --- a/src/yt_study/youtube/transcript.py +++ b/src/yt_study/core/youtube/transcript.py @@ -5,7 +5,6 @@ from dataclasses import dataclass from typing import Any -from rich.console import Console from youtube_transcript_api import YouTubeTranscriptApi from youtube_transcript_api._errors import ( IpBlocked, @@ -18,7 +17,6 @@ from .metadata import VideoChapter -console = Console() logger = logging.getLogger(__name__) diff --git a/src/yt_study/pipeline/__init__.py b/src/yt_study/pipeline/__init__.py deleted file mode 100644 index d4edac6..0000000 --- a/src/yt_study/pipeline/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Pipeline orchestration module.""" - -from .orchestrator import PipelineOrchestrator - - -__all__ = ["PipelineOrchestrator"] diff --git a/src/yt_study/pipeline/orchestrator.py b/src/yt_study/pipeline/orchestrator.py deleted file mode 100644 index c183699..0000000 --- a/src/yt_study/pipeline/orchestrator.py +++ /dev/null @@ -1,534 +0,0 @@ -"""Main pipeline orchestrator with concurrent processing.""" - -import asyncio -import logging -import re -from pathlib import Path - -from rich.console import Console -from rich.live import Live -from rich.panel import Panel -from rich.progress import Progress, TaskID -from rich.table import Table - -from ..config import config -from ..llm.generator import StudyMaterialGenerator -from ..llm.providers import get_provider -from ..ui.dashboard import PipelineDashboard -from ..youtube.metadata import ( - get_playlist_info, - get_video_chapters, - get_video_duration, - get_video_title, -) -from ..youtube.parser import parse_youtube_url -from ..youtube.playlist import extract_playlist_videos -from ..youtube.transcript import ( - YouTubeIPBlockError, - fetch_transcript, - split_transcript_by_chapters, -) - - -console = Console() -logger = logging.getLogger(__name__) - - -def sanitize_filename(name: str) -> str: - """ - Sanitize a string to be used as a filename. - - Args: - name: Raw filename string. - - Returns: - Sanitized string safe for file systems. - """ - # Remove or replace invalid characters - name = re.sub(r'[<>:"/\\|?*]', "", name) - # Replace multiple spaces with single space - name = re.sub(r"\s+", " ", name) - # Trim and limit length - name = name.strip()[:100] - return name if name else "untitled" - - -class PipelineOrchestrator: - """ - Orchestrates the end-to-end pipeline for video processing. - - Manages concurrency, error handling, and UI updates. - """ - - def __init__( - self, - model: str = "gemini/gemini-2.0-flash", - output_dir: Path | None = None, - languages: list[str] | None = None, - temperature: float | None = None, - max_tokens: int | None = None, - ): - """ - Initialize orchestrator. - - Args: - model: LLM model string. - output_dir: Output directory path. - languages: Preferred transcript languages. - temperature: LLM temperature (defaults to config.temperature). - max_tokens: Max tokens (defaults to config.max_tokens). - """ - self.model = model - self.output_dir = output_dir or config.default_output_dir - self.languages = languages or config.default_languages - self.temperature = ( - temperature if temperature is not None else config.temperature - ) - self.max_tokens = max_tokens if max_tokens is not None else config.max_tokens - self.provider = get_provider(model) - self.generator = StudyMaterialGenerator( - self.provider, - temperature=self.temperature, - max_tokens=self.max_tokens, - ) - self.semaphore = asyncio.Semaphore(config.max_concurrent_videos) - - def validate_provider(self) -> bool: - """ - Validate that the API key for the selected provider is set. - - Returns: - True if valid (or warning logged), False if critical missing config. - """ - key_name = config.get_api_key_name_for_model(self.model) - - if key_name: - import os - - if not os.environ.get(key_name): - console.print( - f"\n[red bold]✗ Missing API Key for {self.model}[/red bold]" - ) - console.print( - f"[yellow]Expected environment variable: {key_name}[/yellow]" - ) - console.print( - "[dim]Please check your .env file or run:[/dim] " - "[cyan]yt-study setup[/cyan]\n" - ) - return False - - return True - - async def process_video( - self, - video_id: str, - output_path: Path, - progress: Progress | None = None, - task_id: TaskID | None = None, - video_title: str | None = None, - is_playlist: bool = False, - ) -> bool: - """ - Process a single video: fetch transcript and generate study notes. - - Args: - video_id: YouTube Video ID. - output_path: Destination path for the MD file. - progress: Rich Progress instance. - task_id: Rich TaskID. - video_title: Pre-fetched title (optional). - is_playlist: Whether this is part of a playlist (affects UI logging). - - Returns: - True on success, False on failure. - """ - async with self.semaphore: - local_task_id = task_id - - # If standalone (not part of worker pool), create a specific - # bar if requested - if is_playlist and progress and task_id is None: - display_title = (video_title or video_id)[:30] - local_task_id = progress.add_task( - description=f"[cyan]⏳ {display_title}... (Waiting)[/cyan]", - total=None, - ) - - try: - # 1. Fetch Metadata - if not video_title: - # Run in thread to avoid blocking - video_title = await asyncio.to_thread(get_video_title, video_id) - - # Fetch duration and chapters concurrently - duration, chapters = await asyncio.gather( - asyncio.to_thread(get_video_duration, video_id), - asyncio.to_thread(get_video_chapters, video_id), - ) - - title_display = (video_title or video_id)[:40] - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=f"[cyan]📥 {title_display}... (Transcript)[/cyan]", - ) - - # 2. Fetch Transcript - transcript_obj = await fetch_transcript(video_id, self.languages) - - # 3. Determine Generation Strategy - # Use chapters if video is long (>1h) and chapters exist - use_chapters = duration > 3600 and len(chapters) > 0 and not is_playlist - - if use_chapters: - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[cyan]📖 {title_display}... (Chapters)[/cyan]" - ), - ) - # else block removed as redundant - - # Split transcript - chapter_transcripts = split_transcript_by_chapters( - transcript_obj, chapters - ) - - # Create folder for chapter notes - safe_title = sanitize_filename(video_title) - output_folder = self.output_dir / safe_title - output_folder.mkdir(parents=True, exist_ok=True) - - # Generate chapter notes - # Fix: Iterate here and call generator for each chapter - # to save individually - - for i, (chap_title, chap_text) in enumerate( - chapter_transcripts.items(), 1 - ): - status_msg = f"Chapter {i}/{len(chapter_transcripts)}" - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[cyan]🤖 {title_display}... ({status_msg})[/cyan]" - ), - ) - - notes = await self.generator.generate_single_chapter_notes( - chapter_title=chap_title, - chapter_text=chap_text, - ) - - # Save individual chapter - safe_chapter = sanitize_filename(chap_title) - chapter_file = output_folder / f"{i:02d}_{safe_chapter}.md" - chapter_file.write_text(notes, encoding="utf-8") - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=f"[green]✓ {title_display} (Done)[/green]", - completed=True, - ) - - return True - - else: - # Single file generation - transcript_text = transcript_obj.to_text() - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[cyan]🤖 {title_display}... (Generating)[/cyan]" - ), - ) - - notes = await self.generator.generate_study_notes( - transcript_text, - video_title=title_display, - progress=progress, - task_id=local_task_id, - ) - - output_path.parent.mkdir(parents=True, exist_ok=True) - output_path.write_text(notes, encoding="utf-8") - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=f"[green]✓ {title_display} (Done)[/green]", - completed=True, - ) - - return True - - except Exception as e: - logger.error(f"Failed to process {video_id}: {e}") - - err_msg = str(e) - if isinstance(e, YouTubeIPBlockError) or ( - "blocking requests" in err_msg - ): - err_display = "[bold red]IP BLOCKED[/bold red]" - console.print( - Panel( - "[bold red]🚫 YouTube IP Block Detected[/bold red]\n\n" - "YouTube is limiting requests from your IP address.\n" - "[yellow]➤ Recommendation:[/yellow] Use a VPN or " - "wait ~1 hour.", - border_style="red", - ) - ) - else: - err_display = "(Failed)" - - if progress and local_task_id is not None: - progress.update( - local_task_id, - description=( - f"[red]✗ {(video_title or video_id)[:20]}... " - f"{err_display}[/red]" - ), - visible=True, - ) - - return False - - async def _process_with_dashboard( - self, - video_ids: list[str], - playlist_name: str = "Queue", - is_single_video: bool = False, - ) -> int: - """Process a list of videos using the Advanced Dashboard UI.""" - from ..ui.dashboard import PipelineDashboard - - # Initialize Dashboard FIRST to capture all output - # Adjust concurrency display: if total_videos < max_concurrency, - # only show needed workers - actual_concurrency = min(len(video_ids), config.max_concurrent_videos) - - dashboard = PipelineDashboard( - total_videos=len(video_ids), - concurrency=actual_concurrency, - playlist_name=playlist_name, - model_name=self.model, - ) - - success_count = 0 - video_titles = {} - - # Run Live Display (inline, not full screen) - # We start it immediately to show "Fetching metadata..." state - with Live(dashboard, refresh_per_second=10, console=console, screen=False): - # --- Phase 1: Metadata Fetching --- - TITLE_FETCH_CONCURRENCY = 10 - if not is_single_video: - dashboard.update_overall_status( - f"[cyan]📋 Fetching metadata for {len(video_ids)} videos...[/cyan]" - ) - - title_semaphore = asyncio.Semaphore(TITLE_FETCH_CONCURRENCY) - - async def fetch_title_safe(vid: str) -> str: - async with title_semaphore: - try: - return await asyncio.to_thread(get_video_title, vid) - except Exception: - return vid - - # Fetch titles - titles = await asyncio.gather(*(fetch_title_safe(vid) for vid in video_ids)) - video_titles = dict(zip(video_ids, titles, strict=True)) - - # --- Phase 2: Processing --- - if not is_single_video: - dashboard.update_overall_status("[bold blue]Total Progress[/bold blue]") - - # Determine base output folder - if is_single_video: - base_folder = self.output_dir - else: - base_folder = self.output_dir / sanitize_filename(playlist_name) - base_folder.mkdir(parents=True, exist_ok=True) - - # Worker Queue Implementation - queue: asyncio.Queue[str] = asyncio.Queue() - for vid in video_ids: - queue.put_nowait(vid) - - async def worker(worker_idx: int, task_id: TaskID) -> None: - nonlocal success_count - while not queue.empty(): - try: - video_id = queue.get_nowait() - except asyncio.QueueEmpty: - break - - title = video_titles.get(video_id, video_id) - safe_title = sanitize_filename(title) - - if is_single_video: - video_folder = base_folder / safe_title - output_path = video_folder / f"{safe_title}.md" - else: - output_path = base_folder / f"{safe_title}.md" - - # Update status - dashboard.update_worker( - worker_idx, f"[yellow]{title[:30]}...[/yellow]" - ) - - try: - result = await self.process_video( - video_id, - output_path, - progress=dashboard.worker_progress, - task_id=task_id, - video_title=title, - is_playlist=not is_single_video, - ) - - if result: - success_count += 1 - dashboard.add_completion(title) - else: - dashboard.add_failure(title) - - except Exception as e: - logger.error(f"Worker {worker_idx} failed on {video_id}: {e}") - dashboard.update_worker(worker_idx, f"[red]Error: {e}[/red]") - dashboard.add_failure(title) - await asyncio.sleep(2) - finally: - queue.task_done() - - # Worker done - dashboard.update_worker(worker_idx, "[dim]Idle[/dim]") - - try: - workers = [ - asyncio.create_task(worker(i, dashboard.worker_tasks[i])) - for i in range(actual_concurrency) - ] - await asyncio.gather(*workers) - except Exception as e: - logger.error(f"Dashboard execution failed: {e}") - - # Print summary table after dashboard closes - self._print_summary(dashboard) - - return success_count - - def _print_summary(self, dashboard: "PipelineDashboard") -> None: - """Print a summary table of the run.""" - if not dashboard.recent_completions and not dashboard.recent_failures: - return - - summary_table = Table( - title="📊 Processing Summary", - border_style="cyan", - show_header=True, - header_style="bold magenta", - ) - summary_table.add_column("Status", justify="center") - summary_table.add_column("Video Title", style="dim") - - # Add failures first (more important) - if dashboard.recent_failures: - for fail in dashboard.recent_failures: - summary_table.add_row("[bold red]FAILED[/bold red]", fail) - - # Add successes - if dashboard.recent_completions: - for comp in dashboard.recent_completions: - summary_table.add_row("[green]SUCCESS[/green]", comp) - - console.print("\n") - console.print(summary_table) - console.print( - f"\n[bold]Total Completed:[/bold] " - f"{dashboard.overall_progress.tasks[0].completed}/" - f"{dashboard.overall_progress.tasks[0].total}" - ) - console.print("[dim]Check logs for detailed error reports.[/dim]\n") - - async def process_playlist( - self, playlist_id: str, playlist_name: str = "playlist" - ) -> int: - """Process playlist with concurrent dynamic progress bars.""" - video_ids = await extract_playlist_videos(playlist_id) - return await self._process_with_dashboard(video_ids, playlist_name) - - async def run(self, url: str) -> None: - """ - Run the pipeline for a given YouTube URL. - - Args: - url: YouTube video or playlist URL. - """ - # Validate Provider Credentials - if not self.validate_provider(): - return - - try: - # Parse URL - parsed = parse_youtube_url(url) - - if parsed.url_type == "video": - if not parsed.video_id: - console.print("[red]Error: Video ID could not be extracted[/red]") - return - - await self._process_with_dashboard( - [parsed.video_id], - playlist_name="Single Video", - is_single_video=True, - ) - - # Summary is already printed by _process_with_dashboard - - elif parsed.url_type == "playlist": - if not parsed.playlist_id: - console.print( - "[red]Error: Playlist ID could not be extracted[/red]" - ) - return - - # Fetch basic playlist info first - handled in dashboard now - # if needed or kept minimal. Actually, playlist title fetching - # is useful to show BEFORE starting but _process_with_dashboard - # fetches metadata anyway. - # However, to pass playlist_name to dashboard, we might want it. - # But waiting for title can be slow. - # Let's let the dashboard handle titles for videos. - # For playlist title, we can try fast fetch or default to ID. - - # Fetching playlist title here is blocking/slow if not careful. - # Let's just use ID as name initially or fetch it quickly. - # The original code did fetch it. - - # To reduce redundancy, we remove the print statement - # "Playlist: ..." - playlist_title, _ = await asyncio.to_thread( - get_playlist_info, parsed.playlist_id - ) - - # Removed redundant print: - # console.print(f"[cyan]📑 Playlist:[/cyan] {playlist_title}\n") - - await self.process_playlist(parsed.playlist_id, playlist_title) - - # Summary handled by dashboard - - except ValueError as e: - console.print(f"[red]Input Error: {e}[/red]") - except Exception as e: - console.print(f"[red]Unexpected Error: {e}[/red]") - logger.exception("Pipeline run failed") diff --git a/src/yt_study/ui/presenter.py b/src/yt_study/ui/presenter.py new file mode 100644 index 0000000..ad96b64 --- /dev/null +++ b/src/yt_study/ui/presenter.py @@ -0,0 +1,287 @@ +""" +Rich TUI presenter for the pipeline. + +This module bridges CorePipeline (pure logic) with the Rich Dashboard (UI). +It handles URL parsing, playlist detection, and maps pipeline events to +dashboard updates. +""" + +import asyncio +import logging + +from rich.console import Console +from rich.live import Live +from rich.panel import Panel +from rich.table import Table + +from ..core.config import config +from ..core.events import EventType, PipelineEvent +from ..core.pipeline import CorePipeline, sanitize_filename +from ..core.youtube.metadata import get_playlist_info, get_video_title +from ..core.youtube.parser import parse_youtube_url +from ..core.youtube.playlist import extract_playlist_videos +from .dashboard import PipelineDashboard + + +console = Console() +logger = logging.getLogger(__name__) + + +class RichPipelinePresenter: + """ + Adapts CorePipeline events to a Rich Dashboard. + + Acts as the glue between business logic (CorePipeline) and + the terminal UI (PipelineDashboard + Rich Live). + """ + + def __init__( + self, + model: str = "gemini/gemini-2.0-flash", + output_dir=None, + languages=None, + temperature=None, + max_tokens=None, + ): + """ + Initialize presenter with a CorePipeline. + + Args: + model: LLM model string. + output_dir: Output directory path. + languages: Preferred transcript languages. + temperature: LLM temperature. + max_tokens: Max tokens for generation. + """ + self.model = model + self.pipeline = CorePipeline( + model=model, + output_dir=output_dir, + languages=languages, + temperature=temperature, + max_tokens=max_tokens, + ) + self.dashboard: PipelineDashboard | None = None + self.video_titles: dict[str, str] = {} + + def validate_provider(self) -> bool: + """ + Validate that the API key for the selected provider is set. + + Returns: + True if valid, False if missing. + """ + key_name = config.get_api_key_name_for_model(self.model) + + if key_name: + import os + + if not os.environ.get(key_name): + console.print( + f"\n[red bold]✗ Missing API Key for {self.model}[/red bold]" + ) + console.print( + f"[yellow]Expected environment variable: {key_name}[/yellow]" + ) + console.print( + "[dim]Please check your .env file or run:[/dim] " + "[cyan]yt-study setup[/cyan]\n" + ) + return False + + return True + + def _handle_event(self, event: PipelineEvent) -> None: + """Map CorePipeline events to Dashboard updates.""" + if not self.dashboard: + return + + if event.event_type == EventType.VIDEO_SUCCESS: + title = event.title or event.video_id + self.dashboard.add_completion(title) + + elif event.event_type == EventType.VIDEO_FAILED: + title = event.title or event.video_id + self.dashboard.add_failure(title) + + # Special handling for IP blocks + if event.error and "IP blocked" in event.error: + console.print( + Panel( + "[bold red]🚫 YouTube IP Block Detected[/bold red]\n\n" + "YouTube is limiting requests from your IP address.\n" + "[yellow]➤ Recommendation:[/yellow] Use a VPN or " + "wait ~1 hour.", + border_style="red", + ) + ) + + async def _process_with_dashboard( + self, + video_ids: list[str], + playlist_name: str = "Queue", + is_single_video: bool = False, + ) -> int: + """Process a list of videos using the Dashboard UI.""" + actual_concurrency = min(len(video_ids), config.max_concurrent_videos) + + self.dashboard = PipelineDashboard( + total_videos=len(video_ids), + concurrency=actual_concurrency, + playlist_name=playlist_name, + model_name=self.model, + ) + + # --- Phase 1: Metadata Fetching --- + with Live(self.dashboard, refresh_per_second=10, console=console, screen=False): + TITLE_FETCH_CONCURRENCY = 10 + if not is_single_video: + self.dashboard.update_overall_status( + f"[cyan]📋 Fetching metadata for {len(video_ids)} videos...[/cyan]" + ) + + title_semaphore = asyncio.Semaphore(TITLE_FETCH_CONCURRENCY) + + async def fetch_title_safe(vid: str) -> str: + async with title_semaphore: + try: + return await asyncio.to_thread(get_video_title, vid) + except Exception: + return vid + + titles = await asyncio.gather( + *(fetch_title_safe(vid) for vid in video_ids) + ) + self.video_titles = dict(zip(video_ids, titles, strict=True)) + + # --- Phase 2: Processing --- + if not is_single_video: + self.dashboard.update_overall_status( + "[bold blue]Total Progress[/bold blue]" + ) + + # Determine base output folder + if is_single_video: + base_folder = self.pipeline.output_dir + else: + base_folder = self.pipeline.output_dir / sanitize_filename( + playlist_name + ) + base_folder.mkdir(parents=True, exist_ok=True) + + # Override pipeline output_dir for this run + original_output = self.pipeline.output_dir + + if is_single_video: + # For single video, create a subfolder with the video title + for vid in video_ids: + title = self.video_titles.get(vid, vid) + safe_title = sanitize_filename(title) + video_folder = base_folder / safe_title + self.pipeline.output_dir = video_folder + else: + self.pipeline.output_dir = base_folder + + # Update worker display + for i, vid in enumerate(video_ids[:actual_concurrency]): + title = self.video_titles.get(vid, vid) + self.dashboard.update_worker( + i, f"[yellow]{title[:30]}...[/yellow]" + ) + + result = await self.pipeline.run(video_ids, on_event=self._handle_event) + + # Restore original output dir + self.pipeline.output_dir = original_output + + # Print summary table after dashboard closes + self._print_summary() + + return result.success_count + + def _print_summary(self) -> None: + """Print a summary table of the run.""" + if not self.dashboard: + return + + if ( + not self.dashboard.recent_completions + and not self.dashboard.recent_failures + ): + return + + summary_table = Table( + title="📊 Processing Summary", + border_style="cyan", + show_header=True, + header_style="bold magenta", + ) + summary_table.add_column("Status", justify="center") + summary_table.add_column("Video Title", style="dim") + + # Add failures first (more important) + if self.dashboard.recent_failures: + for fail in self.dashboard.recent_failures: + summary_table.add_row("[bold red]FAILED[/bold red]", fail) + + # Add successes + if self.dashboard.recent_completions: + for comp in self.dashboard.recent_completions: + summary_table.add_row("[green]SUCCESS[/green]", comp) + + console.print("\n") + console.print(summary_table) + console.print( + f"\n[bold]Total Completed:[/bold] " + f"{self.dashboard.overall_progress.tasks[0].completed}/" + f"{self.dashboard.overall_progress.tasks[0].total}" + ) + console.print("[dim]Check logs for detailed error reports.[/dim]\n") + + async def run(self, url: str) -> None: + """ + Run the pipeline for a given YouTube URL. + + Args: + url: YouTube video or playlist URL. + """ + # Validate Provider Credentials + if not self.validate_provider(): + return + + try: + # Parse URL + parsed = parse_youtube_url(url) + + if parsed.url_type == "video": + if not parsed.video_id: + console.print( + "[red]Error: Video ID could not be extracted[/red]" + ) + return + + await self._process_with_dashboard( + [parsed.video_id], + playlist_name="Single Video", + is_single_video=True, + ) + + elif parsed.url_type == "playlist": + if not parsed.playlist_id: + console.print( + "[red]Error: Playlist ID could not be extracted[/red]" + ) + return + + playlist_title, _ = await asyncio.to_thread( + get_playlist_info, parsed.playlist_id + ) + + video_ids = await extract_playlist_videos(parsed.playlist_id) + await self._process_with_dashboard(video_ids, playlist_title) + + except ValueError as e: + console.print(f"[red]Input Error: {e}[/red]") + except Exception as e: + console.print(f"[red]Unexpected Error: {e}[/red]") + logger.exception("Pipeline run failed") diff --git a/src/yt_study/setup_wizard.py b/src/yt_study/ui/wizard.py similarity index 100% rename from src/yt_study/setup_wizard.py rename to src/yt_study/ui/wizard.py diff --git a/tests/conftest.py b/tests/conftest.py index 45f70dc..0db75b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -33,7 +33,7 @@ def mock_config(monkeypatch): # Reload config to pick up env vars if necessary, # or just rely on Config loading from env. - from yt_study.config import config + from yt_study.core.config import config config.gemini_api_key = "dummy_gemini_key" config.openai_api_key = "dummy_openai_key" @@ -52,17 +52,17 @@ def mock_llm_provider(): @pytest.fixture def mock_transcript_api(mocker): """Mock YouTubeTranscriptApi class.""" - return mocker.patch("yt_study.youtube.transcript.YouTubeTranscriptApi") + return mocker.patch("yt_study.core.youtube.transcript.YouTubeTranscriptApi") @pytest.fixture def mock_pytube(mocker): """Mock pytubefix YouTube and Playlist classes.""" # We patch the classes where they are imported in metadata.py - mock_yt = mocker.patch("yt_study.youtube.metadata.YouTube") - mock_pl = mocker.patch("yt_study.youtube.metadata.Playlist") + mock_yt = mocker.patch("yt_study.core.youtube.metadata.YouTube") + mock_pl = mocker.patch("yt_study.core.youtube.metadata.Playlist") # Also patch in playlist.py if used there - mocker.patch("yt_study.youtube.playlist.Playlist", new=mock_pl) + mocker.patch("yt_study.core.youtube.playlist.Playlist", new=mock_pl) return mock_yt, mock_pl diff --git a/tests/test_cli.py b/tests/test_cli.py index 16ec627..286863e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,9 +12,9 @@ @pytest.fixture -def mock_orchestrator(): # noqa: ARG001 - # Patch where PipelineOrchestrator is defined - with patch("yt_study.pipeline.orchestrator.PipelineOrchestrator") as mock: +def mock_presenter(): # noqa: ARG001 + # Patch where RichPipelinePresenter is used in CLI + with patch("yt_study.ui.presenter.RichPipelinePresenter") as mock: instance = mock.return_value instance.run = AsyncMock() yield mock @@ -36,12 +36,6 @@ def test_version(): def test_version_import_error(): """Test version command handles missing __version__ gracefully.""" with patch.dict("sys.modules", {"yt_study": None}): - # Mocking import error for specific attribute is tricky with sys.modules - # simpler to patch the import statement inside cli.py if possible, - # or just assume the fallback logic works if __version__ is missing. - # Let's try patching builtins.__import__ specifically for that - # module? Too complex. - # Just manually call the function? No, tested via runner. pass @@ -61,15 +55,15 @@ def test_config_path_missing(): assert "No configuration found" in result.stdout -def test_process_url_success(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_url_success(mock_config_exists, mock_presenter): # noqa: ARG001 """Test processing a simple URL.""" result = runner.invoke(app, ["process", "https://youtube.com/watch?v=123"]) assert result.exit_code == 0 - mock_orchestrator.return_value.run.assert_awaited() + mock_presenter.return_value.run.assert_awaited() -def test_process_batch_file(mock_config_exists, mock_orchestrator, tmp_path): # noqa: ARG001 +def test_process_batch_file(mock_config_exists, mock_presenter, tmp_path): # noqa: ARG001 """Test processing a batch file.""" batch_file = tmp_path / "urls.txt" batch_file.write_text("https://yt.com/v1\nhttps://yt.com/v2") @@ -77,10 +71,10 @@ def test_process_batch_file(mock_config_exists, mock_orchestrator, tmp_path): # result = runner.invoke(app, ["process", str(batch_file)]) assert result.exit_code == 0 - assert mock_orchestrator.return_value.run.await_count == 2 + assert mock_presenter.return_value.run.await_count == 2 -def test_process_batch_file_empty(mock_config_exists, mock_orchestrator, tmp_path): # noqa: ARG001 +def test_process_batch_file_empty(mock_config_exists, mock_presenter, tmp_path): # noqa: ARG001 """Test processing an empty batch file.""" batch_file = tmp_path / "empty.txt" batch_file.write_text("") @@ -89,10 +83,10 @@ def test_process_batch_file_empty(mock_config_exists, mock_orchestrator, tmp_pat assert result.exit_code == 0 assert "Batch file is empty" in result.stdout - mock_orchestrator.return_value.run.assert_not_awaited() + mock_presenter.return_value.run.assert_not_awaited() -def test_process_batch_file_error(mock_config_exists, mock_orchestrator, tmp_path): # noqa: ARG001 +def test_process_batch_file_error(mock_config_exists, mock_presenter, tmp_path): # noqa: ARG001 """Test error reading batch file.""" batch_file = tmp_path / "restricted.txt" batch_file.touch() @@ -101,12 +95,7 @@ def test_process_batch_file_error(mock_config_exists, mock_orchestrator, tmp_pat with patch("pathlib.Path.read_text", side_effect=OSError("Access denied")): result = runner.invoke(app, ["process", str(batch_file)]) - assert ( - result.exit_code == 0 - ) # It returns early, exit code 0 usually unless exception propagates - # Wait, cli.py does return, so exit code 0 is correct for Typer - # unless we raise Exit. - # Checks stdout + assert result.exit_code == 0 assert "Error reading batch file" in result.stdout @@ -114,28 +103,25 @@ def test_process_missing_config(): """Test that missing config triggers setup check/error.""" with ( patch("yt_study.cli.check_config_exists", return_value=False), - patch("yt_study.setup_wizard.run_setup_wizard") as mock_setup, + patch("yt_study.ui.wizard.run_setup_wizard") as mock_setup, ): runner.invoke(app, ["process", "url"]) mock_setup.assert_called_once() -def test_process_keyboard_interrupt(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_keyboard_interrupt(mock_config_exists, mock_presenter): # noqa: ARG001 """Test handling of KeyboardInterrupt.""" - mock_orchestrator.return_value.run.side_effect = KeyboardInterrupt() + mock_presenter.return_value.run.side_effect = KeyboardInterrupt() result = runner.invoke(app, ["process", "url"]) assert result.exit_code == 1 - # Check for Rich Panel content format - # Rich markup uses symbols like ⚠ which might be encoded differently - # Let's match partial string content "Process interrupted" assert "Process interrupted by user" in result.stdout -def test_process_general_exception(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_general_exception(mock_config_exists, mock_presenter): # noqa: ARG001 """Test handling of general exceptions.""" - mock_orchestrator.return_value.run.side_effect = Exception("Boom") + mock_presenter.return_value.run.side_effect = Exception("Boom") result = runner.invoke(app, ["process", "url"]) @@ -146,7 +132,7 @@ def test_process_general_exception(mock_config_exists, mock_orchestrator): # no def test_setup_command(): """Test setup command triggers wizard.""" - with patch("yt_study.setup_wizard.run_setup_wizard") as mock_wizard: + with patch("yt_study.ui.wizard.run_setup_wizard") as mock_wizard: result = runner.invoke(app, ["setup"]) assert result.exit_code == 0 mock_wizard.assert_called_once() @@ -154,10 +140,7 @@ def test_setup_command(): def test_setup_import_error(): """Test setup command handling missing wizard module.""" - # Simulate ImportError when importing setup_wizard - with patch.dict("sys.modules", {"yt_study.setup_wizard": None}): - # This approach is tricky because we are inside the test process. - # Better to patch the specific import or function call if lazy. + with patch.dict("sys.modules", {"yt_study.ui.wizard": None}): pass @@ -165,9 +148,8 @@ def test_ensure_setup_import_error(): """Test ensure_setup handles missing wizard.""" with ( patch("yt_study.cli.check_config_exists", return_value=False), - patch.dict("sys.modules", {"yt_study.setup_wizard": None}), + patch.dict("sys.modules", {"yt_study.ui.wizard": None}), ): - # This won't work easily as the module is likely already imported. pass @@ -178,7 +160,7 @@ def test_callback_help(): assert "Usage" in result.stdout -def test_process_with_temperature_flag(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_with_temperature_flag(mock_config_exists, mock_presenter): # noqa: ARG001 """Test processing with custom temperature parameter.""" result = runner.invoke( app, @@ -186,12 +168,12 @@ def test_process_with_temperature_flag(mock_config_exists, mock_orchestrator): ) assert result.exit_code == 0 - call_kwargs = mock_orchestrator.call_args[1] + call_kwargs = mock_presenter.call_args[1] assert call_kwargs["temperature"] == 0.5 - mock_orchestrator.return_value.run.assert_awaited() + mock_presenter.return_value.run.assert_awaited() -def test_process_with_max_tokens_flag(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_with_max_tokens_flag(mock_config_exists, mock_presenter): # noqa: ARG001 """Test processing with custom max_tokens parameter.""" result = runner.invoke( app, @@ -199,12 +181,12 @@ def test_process_with_max_tokens_flag(mock_config_exists, mock_orchestrator): # ) assert result.exit_code == 0 - call_kwargs = mock_orchestrator.call_args[1] + call_kwargs = mock_presenter.call_args[1] assert call_kwargs["max_tokens"] == 2000 - mock_orchestrator.return_value.run.assert_awaited() + mock_presenter.return_value.run.assert_awaited() -def test_process_with_temperature_and_max_tokens(mock_config_exists, mock_orchestrator): # noqa: ARG001 +def test_process_with_temperature_and_max_tokens(mock_config_exists, mock_presenter): # noqa: ARG001 """Test processing with both temperature and max_tokens parameters.""" result = runner.invoke( app, @@ -219,7 +201,7 @@ def test_process_with_temperature_and_max_tokens(mock_config_exists, mock_orches ) assert result.exit_code == 0 - call_kwargs = mock_orchestrator.call_args[1] + call_kwargs = mock_presenter.call_args[1] assert call_kwargs["temperature"] == 0.8 assert call_kwargs["max_tokens"] == 3000 - mock_orchestrator.return_value.run.assert_awaited() + mock_presenter.return_value.run.assert_awaited() diff --git a/tests/test_config.py b/tests/test_config.py index c771c11..8e37a38 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,7 +3,7 @@ import os from unittest.mock import patch -from yt_study.config import Config +from yt_study.core.config import Config class TestConfig: diff --git a/tests/test_core_pipeline.py b/tests/test_core_pipeline.py new file mode 100644 index 0000000..51bdf23 --- /dev/null +++ b/tests/test_core_pipeline.py @@ -0,0 +1,214 @@ +"""Tests for the core pipeline.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from yt_study.core.events import EventType, PipelineResult +from yt_study.core.pipeline import CorePipeline, sanitize_filename + + +class TestSanitizeFilename: + """Test filename sanitization.""" + + def test_basic_name(self): + assert sanitize_filename("Hello World") == "Hello World" + + def test_special_characters(self): + assert sanitize_filename('foo/bar:baz"qux') == "foobarbazqux" + + def test_whitespace_normalization(self): + assert sanitize_filename(" spaces ") == "spaces" + + def test_length_limit(self): + assert len(sanitize_filename("a" * 200)) == 100 + + def test_empty_returns_untitled(self): + assert sanitize_filename("") == "untitled" + + def test_dot_traversal(self): + """Prevent directory traversal with '.' title.""" + assert sanitize_filename(".") == "untitled" + + def test_dotdot_traversal(self): + """Prevent directory traversal with '..' title.""" + assert sanitize_filename("..") == "untitled" + + def test_only_special_chars(self): + """All chars removed → returns untitled.""" + assert sanitize_filename('<>:"/\\|?*') == "untitled" + + +class TestCorePipelineInit: + """Test pipeline initialization.""" + + def test_defaults(self, temp_output_dir, mock_llm_provider): + with patch( + "yt_study.core.pipeline.get_provider", + return_value=mock_llm_provider, + ): + pipeline = CorePipeline( + model="mock-model", output_dir=temp_output_dir + ) + assert pipeline.model == "mock-model" + assert pipeline.output_dir == temp_output_dir + + def test_custom_params(self, temp_output_dir, mock_llm_provider): + with patch( + "yt_study.core.pipeline.get_provider", + return_value=mock_llm_provider, + ): + pipeline = CorePipeline( + model="mock-model", + output_dir=temp_output_dir, + temperature=0.5, + max_tokens=1000, + ) + assert pipeline.temperature == 0.5 + assert pipeline.max_tokens == 1000 + + +class TestCorePipelineRun: + """Test pipeline execution.""" + + @pytest.fixture + def pipeline(self, temp_output_dir, mock_llm_provider): + with patch( + "yt_study.core.pipeline.get_provider", + return_value=mock_llm_provider, + ): + p = CorePipeline(model="mock-model", output_dir=temp_output_dir) + p.generator = MagicMock() + p.generator.generate_study_notes = AsyncMock(return_value="# Notes") + p.generator.generate_single_chapter_notes = AsyncMock( + return_value="# Chapter Notes" + ) + return p + + @pytest.mark.asyncio + async def test_empty_list(self, pipeline): + """Empty video list returns zero counts.""" + result = await pipeline.run([]) + assert result.success_count == 0 + assert result.failure_count == 0 + assert result.total_count == 0 + + @pytest.mark.asyncio + async def test_success(self, pipeline): + """Successful processing emits VIDEO_SUCCESS and writes file.""" + events = [] + + with ( + patch( + "yt_study.core.pipeline.get_video_title", + return_value="Test Video", + ), + patch("yt_study.core.pipeline.get_video_duration", return_value=100), + patch("yt_study.core.pipeline.get_video_chapters", return_value=[]), + patch( + "yt_study.core.pipeline.fetch_transcript", + new_callable=AsyncMock, + ) as mock_fetch, + ): + mock_transcript = MagicMock() + mock_transcript.to_text.return_value = "Transcript text" + mock_fetch.return_value = mock_transcript + + result = await pipeline.run( + ["vid123"], + on_event=lambda e: events.append(e), + ) + + assert result.success_count == 1 + assert result.failure_count == 0 + + # Check events + event_types = [e.event_type for e in events] + assert EventType.PIPELINE_START in event_types + assert EventType.VIDEO_SUCCESS in event_types + assert EventType.PIPELINE_COMPLETE in event_types + + @pytest.mark.asyncio + async def test_failure_emits_event(self, pipeline): + """Exception during processing emits VIDEO_FAILED.""" + events = [] + + with ( + patch( + "yt_study.core.pipeline.get_video_title", + side_effect=Exception("Network error"), + ), + ): + result = await pipeline.run( + ["bad_vid"], + on_event=lambda e: events.append(e), + ) + + assert result.success_count == 0 + assert result.failure_count == 1 + assert "bad_vid" in result.errors + + event_types = [e.event_type for e in events] + assert EventType.VIDEO_FAILED in event_types + + @pytest.mark.asyncio + async def test_missing_api_key(self, pipeline, monkeypatch): + """Missing API key fails all videos.""" + with patch( + "yt_study.core.config.config.get_api_key_name_for_model", + return_value="TEST_KEY", + ): + monkeypatch.delenv("TEST_KEY", raising=False) + result = await pipeline.run(["vid1", "vid2"]) + + assert result.success_count == 0 + assert result.failure_count == 2 + assert result.total_count == 2 + + @pytest.mark.asyncio + async def test_with_chapters(self, pipeline): + """Test processing a video with chapters.""" + with ( + patch( + "yt_study.core.pipeline.get_video_title", + return_value="Long Video", + ), + patch("yt_study.core.pipeline.get_video_duration", return_value=4000), + patch( + "yt_study.core.pipeline.get_video_chapters", + return_value=["chap1"], + ), + patch( + "yt_study.core.pipeline.fetch_transcript", + new_callable=AsyncMock, + ) as mock_fetch, + patch( + "yt_study.core.pipeline.split_transcript_by_chapters", + return_value={"Ch1": "text"}, + ), + ): + mock_transcript = MagicMock() + mock_fetch.return_value = mock_transcript + + result = await pipeline.run(["vid123"]) + + assert result.success_count == 1 + + # Verify folder creation + expected_folder = pipeline.output_dir / "Long Video" + assert expected_folder.exists() + assert (expected_folder / "01_Ch1.md").exists() + + @pytest.mark.asyncio + async def test_pipeline_result_structure(self, pipeline): + """Verify PipelineResult has correct structure.""" + with patch.object( + pipeline, "_process_single_video", new_callable=AsyncMock + ) as mock_proc: + mock_proc.return_value = True + + result = await pipeline.run(["v1", "v2"]) + + assert isinstance(result, PipelineResult) + assert result.total_count == 2 + assert result.video_ids == ["v1", "v2"] diff --git a/tests/test_llm/test_generator.py b/tests/test_llm/test_generator.py index f524a03..772263a 100644 --- a/tests/test_llm/test_generator.py +++ b/tests/test_llm/test_generator.py @@ -4,8 +4,8 @@ import pytest -from yt_study.config import config -from yt_study.llm.generator import StudyMaterialGenerator +from yt_study.core.config import config +from yt_study.core.llm.generator import StudyMaterialGenerator class TestStudyMaterialGenerator: @@ -18,14 +18,14 @@ def generator(self, mock_llm_provider): def test_count_tokens_fallback(self, generator): """Test token counting fallback when library fails.""" with patch( - "yt_study.llm.generator.token_counter", side_effect=Exception("Error") + "yt_study.core.llm.generator.token_counter", side_effect=Exception("Error") ): count = generator._count_tokens("1234") assert count == 1 # 4 chars // 4 = 1 def test_chunk_transcript_small(self, generator): """Test that small transcripts are not chunked.""" - with patch("yt_study.llm.generator.token_counter", return_value=100): + with patch("yt_study.core.llm.generator.token_counter", return_value=100): chunks = generator._chunk_transcript("Small text") assert len(chunks) == 1 assert chunks[0] == "Small text" @@ -36,7 +36,7 @@ def test_chunk_transcript_sentences(self, generator): config.chunk_size = 5 # Allow room for a sentence + delimiter try: - with patch("yt_study.llm.generator.token_counter") as mock_tc: + with patch("yt_study.core.llm.generator.token_counter") as mock_tc: # 1 token per word, with the delimiter ". " adding extra tokens def count_tokens(_model, text): # noqa: ARG001 return len(text.split()) @@ -60,7 +60,7 @@ def test_chunk_transcript_newlines(self, generator): config.chunk_size = 2 try: - with patch("yt_study.llm.generator.token_counter") as mock_tc: + with patch("yt_study.core.llm.generator.token_counter") as mock_tc: mock_tc.side_effect = lambda _model, text: len(text.split()) # noqa: ARG005 # No periods, just newlines @@ -78,7 +78,7 @@ def test_chunk_transcript_hard_split(self, generator): config.chunk_size = 1 # Tiny try: - with patch("yt_study.llm.generator.token_counter") as mock_tc: + with patch("yt_study.core.llm.generator.token_counter") as mock_tc: # Mock token counter to say everything is too big mock_tc.side_effect = lambda _model, text: len(text) # noqa: ARG005 diff --git a/tests/test_llm/test_providers.py b/tests/test_llm/test_providers.py index b07e6a9..b77f5e1 100644 --- a/tests/test_llm/test_providers.py +++ b/tests/test_llm/test_providers.py @@ -4,7 +4,7 @@ import pytest -from yt_study.llm.providers import LLMGenerationError, LLMProvider, get_provider +from yt_study.core.llm.providers import LLMGenerationError, LLMProvider, get_provider class TestLLMProvider: @@ -20,7 +20,7 @@ def test_init_validation(self, mock_config): # noqa: ARG002 @pytest.mark.asyncio async def test_generate_success(self): """Test successful generation.""" - with patch("yt_study.llm.providers.acompletion") as mock_acompletion: + with patch("yt_study.core.llm.providers.acompletion") as mock_acompletion: # Setup mock response mock_response = MagicMock() mock_response.choices[0].message.content = "Generated content" @@ -43,7 +43,7 @@ async def test_generate_success(self): @pytest.mark.asyncio async def test_generate_cleanup_markdown(self): """Test cleaning of markdown code blocks from response.""" - with patch("yt_study.llm.providers.acompletion") as mock_acompletion: + with patch("yt_study.core.llm.providers.acompletion") as mock_acompletion: mock_response = MagicMock() # LLM returns content wrapped in ```markdown ... ``` mock_response.choices[ @@ -59,7 +59,7 @@ async def test_generate_cleanup_markdown(self): @pytest.mark.asyncio async def test_generate_failure(self): """Test generation failure raises custom exception.""" - with patch("yt_study.llm.providers.acompletion") as mock_acompletion: + with patch("yt_study.core.llm.providers.acompletion") as mock_acompletion: mock_acompletion.side_effect = Exception("API Error") provider = LLMProvider("gpt-4o") diff --git a/tests/test_pipeline/test_orchestrator.py b/tests/test_pipeline/test_orchestrator.py index ac71d01..bbf42c2 100644 --- a/tests/test_pipeline/test_orchestrator.py +++ b/tests/test_pipeline/test_orchestrator.py @@ -1,141 +1,61 @@ -"""Tests for pipeline orchestrator.""" +"""Tests for pipeline presenter (Rich UI wrapper).""" from unittest.mock import AsyncMock, MagicMock, patch import pytest -from yt_study.pipeline.orchestrator import PipelineOrchestrator, sanitize_filename +from yt_study.ui.presenter import RichPipelinePresenter -def test_sanitize_filename(): - """Test filename sanitization.""" - assert sanitize_filename("Hello World") == "Hello World" - assert sanitize_filename("foo/bar:baz") == "foobarbaz" - assert sanitize_filename(" spaces ") == "spaces" - assert len(sanitize_filename("a" * 200)) == 100 - - -class TestPipelineOrchestrator: - """Test orchestrator logic.""" +class TestRichPipelinePresenter: + """Test presenter logic.""" @pytest.fixture - def orchestrator(self, temp_output_dir, mock_llm_provider): + def presenter(self, temp_output_dir, mock_llm_provider): with patch( - "yt_study.pipeline.orchestrator.get_provider", + "yt_study.core.pipeline.get_provider", return_value=mock_llm_provider, ): - orch = PipelineOrchestrator(model="mock-model", output_dir=temp_output_dir) - # Mock the generator inside - orch.generator = MagicMock() - orch.generator.generate_study_notes = AsyncMock(return_value="# Notes") - orch.generator.generate_single_chapter_notes = AsyncMock( - return_value="# Chapter Notes" + pres = RichPipelinePresenter( + model="mock-model", output_dir=temp_output_dir ) - orch.generator.provider = ( - mock_llm_provider # needed for direct calls in chapter loop + # Mock the inner pipeline's generator + pres.pipeline.generator = MagicMock() + pres.pipeline.generator.generate_study_notes = AsyncMock( + return_value="# Notes" ) - return orch + pres.pipeline.generator.generate_single_chapter_notes = AsyncMock( + return_value="# Chapter Notes" + ) + return pres - def test_validate_provider_missing_key(self, orchestrator, monkeypatch): + def test_validate_provider_missing_key(self, presenter, monkeypatch): """Test validation fails if key is missing.""" - # Mock config to return key name but env var is empty with patch( - "yt_study.config.config.get_api_key_name_for_model", return_value="TEST_KEY" + "yt_study.core.config.config.get_api_key_name_for_model", + return_value="TEST_KEY", ): monkeypatch.delenv("TEST_KEY", raising=False) - assert orchestrator.validate_provider() is False + assert presenter.validate_provider() is False - def test_validate_provider_success(self, orchestrator, monkeypatch): + def test_validate_provider_success(self, presenter, monkeypatch): """Test validation succeeds if key exists.""" with patch( - "yt_study.config.config.get_api_key_name_for_model", return_value="TEST_KEY" + "yt_study.core.config.config.get_api_key_name_for_model", + return_value="TEST_KEY", ): monkeypatch.setenv("TEST_KEY", "123") - assert orchestrator.validate_provider() is True - - @pytest.mark.asyncio - async def test_process_video_single(self, orchestrator): - """Test processing a single video (no chapters).""" - # Mock dependencies - with ( - patch( - "yt_study.pipeline.orchestrator.get_video_title", - return_value="Test Video", - ), - patch( - "yt_study.pipeline.orchestrator.get_video_duration", return_value=100 - ), - patch("yt_study.pipeline.orchestrator.get_video_chapters", return_value=[]), - patch( - "yt_study.pipeline.orchestrator.fetch_transcript", - new_callable=AsyncMock, - ) as mock_fetch, - ): - mock_transcript = MagicMock() - mock_transcript.to_text.return_value = "Transcript text" - mock_fetch.return_value = mock_transcript - - video_id = "vid123" - output_path = orchestrator.output_dir / "notes.md" - - success = await orchestrator.process_video(video_id, output_path) - - assert success is True - assert output_path.exists() - assert output_path.read_text(encoding="utf-8") == "# Notes" - orchestrator.generator.generate_study_notes.assert_called_once() - - @pytest.mark.asyncio - async def test_process_video_with_chapters(self, orchestrator): - """Test processing a video with chapters.""" - # Mock dependencies - # Duration > 3600 (1h) + Chapters present - with ( - patch( - "yt_study.pipeline.orchestrator.get_video_title", - return_value="Long Video", - ), - patch( - "yt_study.pipeline.orchestrator.get_video_duration", return_value=4000 - ), - patch( - "yt_study.pipeline.orchestrator.get_video_chapters", - return_value=["chap1"], - ), - patch( - "yt_study.pipeline.orchestrator.fetch_transcript", - new_callable=AsyncMock, - ) as mock_fetch, - patch( - "yt_study.pipeline.orchestrator.split_transcript_by_chapters", - return_value={"Ch1": "text"}, - ), - ): - mock_transcript = MagicMock() - mock_fetch.return_value = mock_transcript - - video_id = "vid123" - output_path = ( - orchestrator.output_dir / "ignored.md" - ) # Folder structure used instead - - success = await orchestrator.process_video(video_id, output_path) - - assert success is True - # Verify folder creation - expected_folder = orchestrator.output_dir / "Long Video" - assert expected_folder.exists() - # Verify individual chapter file created (mock provider returns - # default text) - assert (expected_folder / "01_Ch1.md").exists() + assert presenter.validate_provider() is True @pytest.mark.asyncio - async def test_run_video_flow(self, orchestrator): + async def test_run_video_flow(self, presenter): """Test run() method flow for a video URL.""" with ( - patch("yt_study.pipeline.orchestrator.parse_youtube_url") as mock_parse, + patch( + "yt_study.ui.presenter.parse_youtube_url" + ) as mock_parse, patch.object( - orchestrator, "_process_with_dashboard", new_callable=AsyncMock + presenter, "_process_with_dashboard", new_callable=AsyncMock ) as mock_dash, ): mock_parsed = MagicMock() @@ -145,7 +65,7 @@ async def test_run_video_flow(self, orchestrator): mock_dash.return_value = 1 # 1 success - await orchestrator.run("http://url") + await presenter.run("http://url") mock_dash.assert_called_once() args = mock_dash.call_args diff --git a/tests/test_setup_wizard.py b/tests/test_setup_wizard.py index 8718549..f89f8fb 100644 --- a/tests/test_setup_wizard.py +++ b/tests/test_setup_wizard.py @@ -2,7 +2,7 @@ from unittest.mock import mock_open, patch -from yt_study.setup_wizard import ( +from yt_study.ui.wizard import ( get_api_key, get_available_models, load_config, @@ -58,10 +58,11 @@ def test_save_config(self): mock_path = Path("dummy_path") with ( patch( - "yt_study.setup_wizard.load_config", return_value={"OLD_KEY": "old_val"} + "yt_study.ui.wizard.load_config", + return_value={"OLD_KEY": "old_val"}, ), patch("pathlib.Path.open", mock_open()) as mock_file, - patch("yt_study.setup_wizard.get_config_path", return_value=mock_path), + patch("yt_study.ui.wizard.get_config_path", return_value=mock_path), ): new_config = {"NEW_KEY": "new_val", "DEFAULT_MODEL": "new_model"} save_config(new_config) @@ -151,7 +152,7 @@ def test_select_provider(self): } with ( - patch("yt_study.setup_wizard.PROVIDER_CONFIG", test_config), + patch("yt_study.ui.wizard.PROVIDER_CONFIG", test_config), patch("rich.prompt.Prompt.ask", return_value="2"), ): result = select_provider({"p1": [], "p2": []}) @@ -167,7 +168,7 @@ def test_select_model_pagination(self): inputs = ["n", "p", "1"] with ( - patch("yt_study.setup_wizard.PROVIDER_CONFIG", {"p1": {"name": "P1"}}), + patch("yt_study.ui.wizard.PROVIDER_CONFIG", {"p1": {"name": "P1"}}), patch("rich.prompt.Prompt.ask", side_effect=inputs), ): selected = select_model("p1", models) @@ -179,7 +180,8 @@ def test_select_model_gemini_prefix(self): with ( patch( - "yt_study.setup_wizard.PROVIDER_CONFIG", {"gemini": {"name": "Google"}} + "yt_study.ui.wizard.PROVIDER_CONFIG", + {"gemini": {"name": "Google"}}, ), patch("rich.prompt.Prompt.ask", return_value="1"), ): @@ -221,18 +223,19 @@ def test_run_setup_wizard_full_flow(self): """Test full setup flow.""" # Mocks with ( - patch("yt_study.setup_wizard.load_config", return_value={}), + patch("yt_study.ui.wizard.load_config", return_value={}), patch( - "yt_study.setup_wizard.get_available_models", + "yt_study.ui.wizard.get_available_models", return_value={"gemini": ["gemini-pro"]}, ), - patch("yt_study.setup_wizard.select_provider", return_value="gemini"), + patch("yt_study.ui.wizard.select_provider", return_value="gemini"), patch( - "yt_study.setup_wizard.select_model", return_value="gemini/gemini-pro" + "yt_study.ui.wizard.select_model", + return_value="gemini/gemini-pro", ), - patch("yt_study.setup_wizard.get_api_key", return_value="new-key"), + patch("yt_study.ui.wizard.get_api_key", return_value="new-key"), patch("rich.prompt.Prompt.ask", side_effect=["/custom/out", "10"]), - patch("yt_study.setup_wizard.save_config") as mock_save, + patch("yt_study.ui.wizard.save_config") as mock_save, ): config = run_setup_wizard(force=True) @@ -246,7 +249,10 @@ def test_run_setup_wizard_full_flow(self): def test_run_setup_wizard_skip_existing(self): """Test skipping setup if config exists.""" with ( - patch("yt_study.setup_wizard.load_config", return_value={"exists": "true"}), + patch( + "yt_study.ui.wizard.load_config", + return_value={"exists": "true"}, + ), patch("rich.prompt.Confirm.ask", return_value=False), ): # Do not reconfigure config = run_setup_wizard(force=False) diff --git a/tests/test_youtube/test_metadata.py b/tests/test_youtube/test_metadata.py index 2e2749b..8866714 100644 --- a/tests/test_youtube/test_metadata.py +++ b/tests/test_youtube/test_metadata.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock, PropertyMock -from yt_study.youtube.metadata import ( +from yt_study.core.youtube.metadata import ( get_playlist_info, get_video_chapters, get_video_duration, diff --git a/tests/test_youtube/test_parser.py b/tests/test_youtube/test_parser.py index 33d9fef..be3362c 100644 --- a/tests/test_youtube/test_parser.py +++ b/tests/test_youtube/test_parser.py @@ -2,7 +2,7 @@ import pytest -from yt_study.youtube.parser import ( +from yt_study.core.youtube.parser import ( extract_playlist_id, extract_video_id, parse_youtube_url, diff --git a/tests/test_youtube/test_playlist.py b/tests/test_youtube/test_playlist.py index 20e3b2a..9ad1993 100644 --- a/tests/test_youtube/test_playlist.py +++ b/tests/test_youtube/test_playlist.py @@ -4,7 +4,7 @@ import pytest -from yt_study.youtube.playlist import PlaylistError, extract_playlist_videos +from yt_study.core.youtube.playlist import PlaylistError, extract_playlist_videos class TestPlaylistExtraction: diff --git a/tests/test_youtube/test_transcript.py b/tests/test_youtube/test_transcript.py index 5379d4b..f1dc78c 100644 --- a/tests/test_youtube/test_transcript.py +++ b/tests/test_youtube/test_transcript.py @@ -5,8 +5,8 @@ import pytest from youtube_transcript_api._errors import NoTranscriptFound, VideoUnavailable -from yt_study.youtube.metadata import VideoChapter -from yt_study.youtube.transcript import ( +from yt_study.core.youtube.metadata import VideoChapter +from yt_study.core.youtube.transcript import ( TranscriptError, VideoTranscript, fetch_transcript, @@ -21,7 +21,7 @@ class TestFetchTranscript: def mock_transcript_api_instance(self, mocker): """Mock the YouTubeTranscriptApi class and its instance.""" # Patch the class - mock_cls = mocker.patch("yt_study.youtube.transcript.YouTubeTranscriptApi") + mock_cls = mocker.patch("yt_study.core.youtube.transcript.YouTubeTranscriptApi") # The instance returned by constructor mock_instance = mock_cls.return_value return mock_instance diff --git a/uv.lock b/uv.lock index e192735..de01ef1 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10" [[package]]