diff --git a/pyproject.toml b/pyproject.toml index 66a59efd..f72ba036 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ dependencies = [ "packaging>=25.0", "paho-mqtt", "pandas", + "passlib[argon2]>=1.7.4", "plotly", "psutil>=7.0.0", "pydub>=0.25.1", @@ -79,6 +80,7 @@ dependencies = [ "sqladmin>=0.21.0", "sqlalchemy", "sqlmodel>=0.0.24", + "starsessions[redis]>=2.2.1", "structlog>=25.4.0", "suntime", "tqdm>=4.67.1", diff --git a/src/birdnetpi/utils/auth.py b/src/birdnetpi/utils/auth.py new file mode 100644 index 00000000..145e130c --- /dev/null +++ b/src/birdnetpi/utils/auth.py @@ -0,0 +1,182 @@ +"""Authentication utilities for BirdNET-Pi admin interface. + +Provides session-based authentication using Starlette's built-in +authentication system with Redis-backed sessions. +""" + +from collections.abc import Awaitable, Callable +from datetime import datetime + +from passlib.context import CryptContext +from pydantic import BaseModel +from starlette.authentication import ( + AuthCredentials, + AuthenticationBackend, + SimpleUser, +) +from starlette.requests import HTTPConnection +from starlette.responses import RedirectResponse +from starsessions import load_session + +from birdnetpi.system.path_resolver import PathResolver + +# Password hashing context using Argon2 +pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") + + +def require_admin_relative( + redirect_path: str = "/admin/login", +) -> Callable[[Callable[..., Awaitable[object]]], Callable[..., Awaitable[object]]]: + """Create authentication decorator that uses relative URLs for redirects. + + Unlike Starlette's @requires which generates absolute URLs, this decorator + uses relative paths to avoid issues with proxies and URL parsing. + + Args: + redirect_path: Relative path to redirect to if not authenticated + + Returns: + Decorator function that wraps route handlers + """ + from functools import wraps + from urllib.parse import urlencode + + def decorator( + func: Callable[..., Awaitable[object]], + ) -> Callable[..., Awaitable[object]]: + @wraps(func) + async def wrapper(request: HTTPConnection, *args: object, **kwargs: object) -> object: + # Check if user has required authentication scope + # This uses request.auth which works with or without AuthenticationMiddleware + if "authenticated" not in request.auth.scopes: + # Build relative redirect URL with next parameter + next_qparam = urlencode({"next": str(request.url.path)}) + if request.url.query: + next_qparam = urlencode({"next": f"{request.url.path}?{request.url.query}"}) + + redirect_url = f"{redirect_path}?{next_qparam}" + return RedirectResponse(url=redirect_url, status_code=303) + + return await func(request, *args, **kwargs) + + return wrapper + + return decorator + + +# Syntactic sugar for common authentication requirement +# Usage: @require_admin decorator on admin view routes +require_admin = require_admin_relative() + + +class AdminUser(BaseModel): + """Admin user model for file-based storage.""" + + username: str + password_hash: str + created_at: datetime + + +class AuthService: + """Handles admin user file operations and password hashing. + + Stores a single admin user in a JSON file with permissions set to 0600. + """ + + def __init__(self, path_resolver: PathResolver) -> None: + """Initialize auth service. + + Args: + path_resolver: PathResolver instance for determining file paths + """ + self.admin_file = path_resolver.get_data_dir() / "admin_user.json" + + def load_admin_user(self) -> AdminUser | None: + """Load admin user from JSON file. + + Returns: + AdminUser if file exists and is valid, None otherwise + """ + if not self.admin_file.exists(): + return None + + import json + + try: + with open(self.admin_file) as f: + data = json.load(f) + return AdminUser(**data) + except (json.JSONDecodeError, ValueError): + return None + + def save_admin_user(self, username: str, password: str) -> None: + """Hash password and save to JSON with 0600 permissions. + + Args: + username: Admin username + password: Plain text password (will be hashed) + """ + import json + + admin = AdminUser( + username=username, password_hash=pwd_context.hash(password), created_at=datetime.now() + ) + + # Write to file + with open(self.admin_file, "w") as f: + json.dump(admin.model_dump(), f, default=str, indent=2) + + # Set restrictive permissions (owner read/write only) + self.admin_file.chmod(0o600) + + def verify_password(self, password: str, password_hash: str) -> bool: + """Verify password against hash. + + Args: + password: Plain text password to verify + password_hash: Argon2 hash to verify against + + Returns: + True if password matches, False otherwise + """ + return pwd_context.verify(password, password_hash) + + def admin_exists(self) -> bool: + """Check if admin user file exists. + + Returns: + True if admin_user.json exists, False otherwise + """ + return self.admin_file.exists() + + +class SessionAuthBackend(AuthenticationBackend): + """Session-based authentication backend for Starlette. + + Checks for username in session and returns appropriate credentials. + """ + + async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, SimpleUser] | None: + """Authenticate request based on session data. + + Called by AuthenticationMiddleware on every request. Explicitly loads + session from starsessions middleware before accessing it. + + Args: + conn: HTTP connection (request or WebSocket) + + Returns: + Tuple of (AuthCredentials, SimpleUser) if authenticated, + None if not authenticated + """ + # Load session from starsessions middleware + await load_session(conn) + + # Get username from session + username = conn.session.get("username") + if not username: + return None # Not authenticated + + # Return authenticated user with "authenticated" scope + # The scope is used by @requires decorator for authorization + return AuthCredentials(["authenticated"]), SimpleUser(username) diff --git a/src/birdnetpi/web/core/container.py b/src/birdnetpi/web/core/container.py index 8fbed5a4..68cc903d 100644 --- a/src/birdnetpi/web/core/container.py +++ b/src/birdnetpi/web/core/container.py @@ -27,6 +27,7 @@ from birdnetpi.system.log_reader import LogReaderService from birdnetpi.system.path_resolver import PathResolver from birdnetpi.system.system_control import SystemControlService +from birdnetpi.utils.auth import AuthService from birdnetpi.utils.cache import Cache from birdnetpi.web.core.config import get_config @@ -128,6 +129,19 @@ class Container(containers.DeclarativeContainer): enable_cache_warming=True, ) + # Authentication services + auth_service = providers.Singleton( + AuthService, + path_resolver=path_resolver, + ) + + # Redis client for session storage - singleton + redis_client = providers.Singleton( + lambda: __import__("redis.asyncio", fromlist=["Redis"]).Redis.from_url( + "redis://127.0.0.1:6379" + ), + ) + # Core business services - singletons file_manager = providers.Singleton( FileManager, diff --git a/src/birdnetpi/web/core/factory.py b/src/birdnetpi/web/core/factory.py index e5ccad6d..1bde7646 100644 --- a/src/birdnetpi/web/core/factory.py +++ b/src/birdnetpi/web/core/factory.py @@ -3,18 +3,24 @@ from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse +from starlette.middleware.authentication import AuthenticationMiddleware +from starsessions import SessionMiddleware +from starsessions.stores.redis import RedisStore from birdnetpi.config.manager import ConfigManager from birdnetpi.i18n.translation_manager import setup_jinja2_i18n from birdnetpi.system.status import SystemInspector +from birdnetpi.utils.auth import SessionAuthBackend from birdnetpi.utils.language import get_user_language from birdnetpi.web.core.container import Container from birdnetpi.web.core.lifespan import lifespan from birdnetpi.web.middleware.i18n import LanguageMiddleware from birdnetpi.web.middleware.request_logging import StructuredRequestLoggingMiddleware +from birdnetpi.web.middleware.setup_redirect import SetupRedirectMiddleware from birdnetpi.web.middleware.update_banner import add_update_status_to_templates from birdnetpi.web.routers import ( analysis_api_routes, + auth_routes, detections_api_routes, health_api_routes, i18n_api_routes, @@ -75,6 +81,37 @@ def create_app() -> FastAPI: expose_headers=["*"], # Expose all headers including Content-Type ) + # Authentication and session middleware + # NOTE: Middleware is stacked in reverse order - last added runs first! + # Desired execution order: Session → Auth → SetupRedirect → App + # So add in reverse: SetupRedirect, Auth, Session + + # 1. Setup redirect (added first, runs last before app) + auth_service = container.auth_service() + app.add_middleware(SetupRedirectMiddleware, auth_service=auth_service) + + # 2. Authentication (added second, runs after session loads) + app.add_middleware( + AuthenticationMiddleware, + backend=SessionAuthBackend(), + ) + + # 3. Session middleware (added last, runs first to load session) + redis_client = container.redis_client() + session_store = RedisStore( + connection=redis_client, + prefix="birdnetpi:", + gc_ttl=86400, # 24 hours + ) + app.add_middleware( + SessionMiddleware, + store=session_store, + lifetime=86400, # 24 hours + rolling=True, # Extend session on each request + cookie_https_only=False, # TODO: Enable in production with HTTPS + cookie_name="birdnetpi_session", + ) + # Add LanguageMiddleware app.add_middleware(LanguageMiddleware) @@ -102,6 +139,7 @@ def create_app() -> FastAPI: modules=[ "birdnetpi.web.core.factory", # Wire factory for root route "birdnetpi.web.routers.analysis_api_routes", + "birdnetpi.web.routers.auth_routes", "birdnetpi.web.routers.detections_api_routes", "birdnetpi.web.routers.health_api_routes", "birdnetpi.web.routers.i18n_api_routes", # Wire i18n API routes @@ -125,6 +163,13 @@ def create_app() -> FastAPI: # === API Routes (included in documentation) === + # Authentication routes (setup, login, logout) + app.include_router( + auth_routes.router, + tags=["Authentication"], + include_in_schema=False, # Exclude from API docs + ) + # Analysis API routes for progressive loading app.include_router(analysis_api_routes.router, prefix="/api", tags=["Analysis API"]) @@ -205,6 +250,12 @@ def create_app() -> FastAPI: # Database administration interface sqladmin_view_routes.setup_sqladmin(app) + # Cleanup Redis connection on shutdown + @app.on_event("shutdown") + async def shutdown() -> None: + """Clean up resources on application shutdown.""" + await redis_client.close() + # Root route (excluded from API documentation) @app.get("/", response_class=HTMLResponse, include_in_schema=False) async def read_root(request: Request) -> HTMLResponse: diff --git a/src/birdnetpi/web/middleware/setup_redirect.py b/src/birdnetpi/web/middleware/setup_redirect.py new file mode 100644 index 00000000..df0bad83 --- /dev/null +++ b/src/birdnetpi/web/middleware/setup_redirect.py @@ -0,0 +1,64 @@ +"""Setup redirect middleware for BirdNET-Pi. + +Redirects all requests to the setup wizard if no admin user exists. +""" + +from collections.abc import Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import RedirectResponse, Response +from starlette.types import ASGIApp + +from birdnetpi.utils.auth import AuthService + + +class SetupRedirectMiddleware(BaseHTTPMiddleware): + """Redirect to setup wizard if no admin user exists. + + This middleware checks if an admin user has been created. If not, + it redirects all requests to /admin/setup except for: + - The setup page itself + - The login page + - Static files + - Health check endpoints + """ + + def __init__(self, app: ASGIApp, auth_service: AuthService): + """Initialize middleware with auth service. + + Args: + app: ASGI application + auth_service: AuthService instance for checking admin existence + """ + super().__init__(app) + self.auth_service = auth_service + + async def dispatch(self, request: Request, call_next: Callable) -> Response: + """Process request and redirect to setup if needed. + + Args: + request: Incoming HTTP request + call_next: Next middleware/endpoint in chain + + Returns: + Response from next handler or redirect to setup + """ + # Paths that should not trigger setup redirect + exempt_paths = [ + "/admin/setup", + "/admin/login", + "/static/", + "/api/health", + ] + + # Allow exempt paths through + if any(request.url.path.startswith(path) for path in exempt_paths): + return await call_next(request) + + # Check if admin user exists + if not self.auth_service.admin_exists(): + return RedirectResponse(url="/admin/setup", status_code=303) + + # Admin exists, continue normally + return await call_next(request) diff --git a/src/birdnetpi/web/routers/auth_routes.py b/src/birdnetpi/web/routers/auth_routes.py new file mode 100644 index 00000000..a7509494 --- /dev/null +++ b/src/birdnetpi/web/routers/auth_routes.py @@ -0,0 +1,211 @@ +"""Authentication routes for admin setup and login.""" + +from typing import Annotated + +from dependency_injector.wiring import Provide, inject +from fastapi import APIRouter, Depends, Form, Request +from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.templating import Jinja2Templates +from starsessions import load_session +from starsessions.session import regenerate_session_id + +from birdnetpi.config import BirdNETConfig +from birdnetpi.i18n.translation_manager import TranslationManager +from birdnetpi.system.status import SystemInspector +from birdnetpi.utils.auth import AuthService +from birdnetpi.utils.language import get_user_language +from birdnetpi.web.core.container import Container + +router = APIRouter() + + +@router.get("/admin/setup", response_class=HTMLResponse, response_model=None) +@inject +async def setup_page( + request: Request, + templates: Annotated[Jinja2Templates, Depends(Provide[Container.templates])], + auth_service: Annotated[AuthService, Depends(Provide[Container.auth_service])], +) -> HTMLResponse | RedirectResponse: + """Show setup wizard page. + + If admin user already exists, redirects to home page. + """ + if auth_service.admin_exists(): + return RedirectResponse(url="/", status_code=303) + + return templates.TemplateResponse( + "admin/setup.html.j2", {"request": request, "prefill_username": "admin"} + ) + + +@router.post("/admin/setup") +@inject +async def create_admin( + request: Request, + auth_service: Annotated[AuthService, Depends(Provide[Container.auth_service])], + username: str = Form(...), + password: str = Form(...), +) -> RedirectResponse: + """Create admin user and log them in. + + Saves admin user with hashed password, creates session, and redirects to home. + """ + # Save admin user (password is automatically hashed) + auth_service.save_admin_user(username, password) + + # Regenerate session ID to prevent session fixation attacks + regenerate_session_id(request) + + # Store username in session + request.session["username"] = username + + # Redirect to home page + return RedirectResponse(url="/", status_code=303) + + +@router.get("/admin/login", response_class=HTMLResponse, name="login") +@inject +async def login_page( + request: Request, + templates: Annotated[Jinja2Templates, Depends(Provide[Container.templates])], + config: Annotated[BirdNETConfig, Depends(Provide[Container.config])], + translation_manager: Annotated[ + TranslationManager, Depends(Provide[Container.translation_manager]) + ], +) -> HTMLResponse: + """Show login page.""" + # Get user language + language = get_user_language(request, config) + _ = translation_manager.get_translation(language).gettext + + # Create context with all required base template variables + context = { + "request": request, + "error": None, + "config": config, + "language": language, + "system_status": {"device_name": SystemInspector.get_device_name()}, + "page_name": _("Administrator Login"), + "active_page": "login", + "model_update_date": None, + } + + return templates.TemplateResponse("admin/login.html.j2", context) + + +@router.post("/admin/login", response_model=None) +@inject +async def login( + request: Request, + templates: Annotated[Jinja2Templates, Depends(Provide[Container.templates])], + auth_service: Annotated[AuthService, Depends(Provide[Container.auth_service])], + config: Annotated[BirdNETConfig, Depends(Provide[Container.config])], + translation_manager: Annotated[ + TranslationManager, Depends(Provide[Container.translation_manager]) + ], + username: str = Form(...), + password: str = Form(...), +) -> HTMLResponse | RedirectResponse: + """Handle login form submission. + + Verifies credentials and creates session on success. Returns to login + page with error on failure. + """ + # Load admin user + admin = auth_service.load_admin_user() + if not admin or admin.username != username: + # Get user language for error template + language = get_user_language(request, config) + _ = translation_manager.get_translation(language).gettext + + context = { + "request": request, + "error": _("Invalid credentials"), + "config": config, + "language": language, + "system_status": {"device_name": SystemInspector.get_device_name()}, + "page_name": _("Administrator Login"), + "active_page": "login", + "model_update_date": None, + } + return templates.TemplateResponse("admin/login.html.j2", context) + + # Verify password + if not auth_service.verify_password(password, admin.password_hash): + # Get user language for error template + language = get_user_language(request, config) + _ = translation_manager.get_translation(language).gettext + + context = { + "request": request, + "error": _("Invalid credentials"), + "config": config, + "language": language, + "system_status": {"device_name": SystemInspector.get_device_name()}, + "page_name": _("Administrator Login"), + "active_page": "login", + "model_update_date": None, + } + return templates.TemplateResponse("admin/login.html.j2", context) + + # Regenerate session ID + regenerate_session_id(request) + + # Store username in session + request.session["username"] = username + + # Redirect to original URL if specified, otherwise home + next_url = request.query_params.get("next", "/") + + # Security: Only allow relative URLs to prevent open redirects + # Our custom @require_admin decorator always uses relative URLs + if next_url.startswith(("http://", "https://", "//")): + # External or protocol-relative URL - ignore and go to home + next_url = "/" + + return RedirectResponse(url=next_url, status_code=303) + + +@router.get("/admin/logout") +async def logout(request: Request) -> RedirectResponse: + """Handle logout. + + Clears session and redirects to login page. + """ + # Load session before accessing it + await load_session(request) + request.session.clear() + return RedirectResponse(url="/admin/login", status_code=303) + + +# SQLAdmin login/logout redirects +@router.get("/admin/database/login", include_in_schema=False) +async def database_login_redirect() -> RedirectResponse: + """Redirect SQLAdmin login to BirdNET-Pi login page.""" + return RedirectResponse(url="/admin/login?next=/admin/database", status_code=303) + + +@router.post("/admin/database/login", include_in_schema=False) +async def database_login_submit_redirect() -> RedirectResponse: + """Redirect SQLAdmin login form submission to BirdNET-Pi login page.""" + return RedirectResponse(url="/admin/login?next=/admin/database", status_code=303) + + +@router.get("/admin/database/logout", include_in_schema=False) +async def database_logout_redirect() -> RedirectResponse: + """Redirect SQLAdmin logout to BirdNET-Pi logout.""" + return RedirectResponse(url="/admin/logout", status_code=303) + + +# API endpoints for authentication status +@router.get("/api/auth/status") +async def auth_status(request: Request) -> dict[str, bool | str | None]: + """Check authentication status. + + Returns: + Dict with authenticated boolean and username if authenticated + """ + return { + "authenticated": request.user.is_authenticated, + "username": request.user.display_name if request.user.is_authenticated else None, + } diff --git a/src/birdnetpi/web/routers/logs_api_routes.py b/src/birdnetpi/web/routers/logs_api_routes.py index bec79833..ccda8814 100644 --- a/src/birdnetpi/web/routers/logs_api_routes.py +++ b/src/birdnetpi/web/routers/logs_api_routes.py @@ -7,10 +7,11 @@ from typing import Annotated, Any from dependency_injector.wiring import Provide, inject -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Depends, Query, Request from fastapi.responses import StreamingResponse from birdnetpi.system.log_reader import LogReaderService +from birdnetpi.utils.auth import require_admin from birdnetpi.web.core.container import Container from birdnetpi.web.models.logs import LOG_LEVELS, LogEntry from birdnetpi.web.models.services import LogsResponse @@ -20,8 +21,10 @@ @router.get("/logs", response_model=LogsResponse) +@require_admin @inject async def get_logs( + request: Request, log_reader: Annotated[LogReaderService, Depends(Provide[Container.log_reader])], start_time: Annotated[datetime | None, Query(description="Start of time range")] = None, end_time: Annotated[datetime | None, Query(description="End of time range")] = None, @@ -33,10 +36,11 @@ async def get_logs( fetched and streamed logs. Args: + request: FastAPI request object (required by authentication decorator) + log_reader: Injected log reader service start_time: Start of time range end_time: End of time range limit: Maximum number of entries - log_reader: Injected log reader service Returns: Dictionary with logs and metadata @@ -120,8 +124,10 @@ async def get_logs( @router.get("/logs/stream") +@require_admin @inject async def stream_logs( + request: Request, log_reader: Annotated[LogReaderService, Depends(Provide[Container.log_reader])], ) -> StreamingResponse: """Stream logs using Server-Sent Events (SSE). @@ -130,6 +136,7 @@ async def stream_logs( fetched and streamed logs. Args: + request: FastAPI request object (required by authentication decorator) log_reader: Injected log reader service Returns: @@ -204,7 +211,8 @@ async def event_generator() -> AsyncIterator[str]: @router.get("/logs/levels") -async def get_log_levels() -> list[dict[str, Any]]: +@require_admin +async def get_log_levels(request: Request) -> list[dict[str, Any]]: """Get available log levels with display information. Returns: diff --git a/src/birdnetpi/web/routers/logs_view_routes.py b/src/birdnetpi/web/routers/logs_view_routes.py index ee20fa63..af7fcd5a 100644 --- a/src/birdnetpi/web/routers/logs_view_routes.py +++ b/src/birdnetpi/web/routers/logs_view_routes.py @@ -13,6 +13,7 @@ from birdnetpi.system.status import SystemInspector from birdnetpi.system.system_control import SERVICES_CONFIG, SystemControlService from birdnetpi.system.system_utils import SystemUtils +from birdnetpi.utils.auth import require_admin from birdnetpi.utils.language import get_user_language from birdnetpi.web.core.container import Container from birdnetpi.web.models.logs import LOG_LEVELS @@ -23,6 +24,7 @@ @router.get("/logs", response_class=HTMLResponse) +@require_admin @inject async def view_logs( request: Request, diff --git a/src/birdnetpi/web/routers/multimedia_view_routes.py b/src/birdnetpi/web/routers/multimedia_view_routes.py index bd6305df..aa992c08 100644 --- a/src/birdnetpi/web/routers/multimedia_view_routes.py +++ b/src/birdnetpi/web/routers/multimedia_view_routes.py @@ -10,6 +10,7 @@ from birdnetpi.config import BirdNETConfig from birdnetpi.i18n.translation_manager import TranslationManager from birdnetpi.system.status import SystemInspector +from birdnetpi.utils.auth import require_admin from birdnetpi.utils.language import get_user_language from birdnetpi.web.core.container import Container from birdnetpi.web.models.template_contexts import LivestreamPageContext @@ -18,6 +19,7 @@ @router.get("/livestream", response_class=HTMLResponse) +@require_admin @inject async def get_livestream( request: Request, diff --git a/src/birdnetpi/web/routers/services_view_routes.py b/src/birdnetpi/web/routers/services_view_routes.py index 96551689..680fda36 100644 --- a/src/birdnetpi/web/routers/services_view_routes.py +++ b/src/birdnetpi/web/routers/services_view_routes.py @@ -14,6 +14,7 @@ from birdnetpi.system.status import SystemInspector from birdnetpi.system.system_control import SERVICES_CONFIG, SystemControlService from birdnetpi.system.system_utils import SystemUtils +from birdnetpi.utils.auth import require_admin from birdnetpi.utils.language import get_user_language from birdnetpi.web.core.container import Container from birdnetpi.web.models.services import ServiceConfig, format_uptime @@ -104,6 +105,7 @@ def _get_system_info( @router.get("/admin/services", response_class=HTMLResponse) +@require_admin @inject async def services_view( request: Request, diff --git a/src/birdnetpi/web/routers/settings_api_routes.py b/src/birdnetpi/web/routers/settings_api_routes.py index 75625384..46fd6d2b 100644 --- a/src/birdnetpi/web/routers/settings_api_routes.py +++ b/src/birdnetpi/web/routers/settings_api_routes.py @@ -4,10 +4,11 @@ import yaml from dependency_injector.wiring import Provide, inject -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Request from birdnetpi.config import ConfigManager from birdnetpi.system.path_resolver import PathResolver +from birdnetpi.utils.auth import require_admin from birdnetpi.web.core.container import Container from birdnetpi.web.models.admin import SaveConfigResponse, ValidationResponse, YAMLConfigRequest @@ -49,8 +50,10 @@ def _validate_yaml_config_impl(yaml_content: str, path_resolver: PathResolver) - @router.post("/settings/validate", response_model=ValidationResponse) +@require_admin @inject async def validate_yaml_config( + request: Request, config_request: YAMLConfigRequest, path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], ) -> ValidationResponse: @@ -60,8 +63,10 @@ async def validate_yaml_config( @router.post("/settings/save", response_model=SaveConfigResponse) +@require_admin @inject async def save_yaml_config( + request: Request, config_request: YAMLConfigRequest, path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], ) -> SaveConfigResponse: diff --git a/src/birdnetpi/web/routers/settings_view_routes.py b/src/birdnetpi/web/routers/settings_view_routes.py index ca7d4843..c05e1a12 100644 --- a/src/birdnetpi/web/routers/settings_view_routes.py +++ b/src/birdnetpi/web/routers/settings_view_routes.py @@ -20,6 +20,7 @@ from birdnetpi.system.log_reader import LogReaderService from birdnetpi.system.path_resolver import PathResolver from birdnetpi.system.status import SystemInspector +from birdnetpi.utils.auth import require_admin from birdnetpi.utils.helpers import prefer from birdnetpi.utils.language import get_user_language from birdnetpi.web.core.container import Container @@ -41,6 +42,7 @@ async def read_admin() -> dict[str, str]: # Settings Management @router.get("/settings", response_class=HTMLResponse) +@require_admin @inject async def get_settings_view( request: Request, diff --git a/src/birdnetpi/web/routers/sqladmin_view_routes.py b/src/birdnetpi/web/routers/sqladmin_view_routes.py index 20648290..6e3d60fb 100644 --- a/src/birdnetpi/web/routers/sqladmin_view_routes.py +++ b/src/birdnetpi/web/routers/sqladmin_view_routes.py @@ -1,12 +1,41 @@ """SQLAdmin configuration and setup for database administration interface.""" +import logging + from fastapi import FastAPI from sqladmin import Admin, ModelView +from sqladmin.authentication import AuthenticationBackend +from starlette.requests import Request from birdnetpi.detections.models import AudioFile, Detection from birdnetpi.location.models import Weather +from birdnetpi.utils.auth import AuthService from birdnetpi.web.core.container import Container +logger = logging.getLogger(__name__) + + +class AdminAuthBackend(AuthenticationBackend): + """Authentication backend for SQLAdmin using BirdNET-Pi authentication.""" + + def __init__(self, secret_key: str, auth_service: AuthService): + """Initialize with secret key and auth service.""" + super().__init__(secret_key) + self.auth_service = auth_service + + async def login(self, request: Request) -> bool: + """Not used - login is handled by BirdNET-Pi auth routes.""" + return False + + async def logout(self, request: Request) -> bool: + """Not used - logout is handled by BirdNET-Pi auth routes.""" + return False + + async def authenticate(self, request: Request) -> bool: + """Check if user is authenticated via Starlette middleware.""" + # Starlette's AuthenticationMiddleware sets request.user + return request.user.is_authenticated + class DetectionAdmin(ModelView, model=Detection): """Admin interface for Detection model.""" @@ -66,16 +95,20 @@ def setup_sqladmin(app: FastAPI) -> Admin: Returns: Configured Admin instance """ - # Get database engine from the DI container + # Get database engine and auth service from the DI container container = Container() core_database = container.core_database() + auth_service = container.auth_service() - # Create admin with custom configuration + # Create admin with custom configuration and authentication admin = Admin( app, core_database.async_engine, base_url="/admin/database", title="BirdNET-Pi Database Admin", + authentication_backend=AdminAuthBackend( + secret_key="not-used-we-use-starlette-auth", auth_service=auth_service + ), ) # Register model views diff --git a/src/birdnetpi/web/routers/system_api_routes.py b/src/birdnetpi/web/routers/system_api_routes.py index 655b87c7..0d1b436f 100644 --- a/src/birdnetpi/web/routers/system_api_routes.py +++ b/src/birdnetpi/web/routers/system_api_routes.py @@ -7,12 +7,13 @@ from typing import Annotated from dependency_injector.wiring import Provide, inject -from fastapi import APIRouter, Depends, HTTPException, Path +from fastapi import APIRouter, Depends, HTTPException, Path, Request from birdnetpi.detections.queries import DetectionQueryService from birdnetpi.system.status import SystemInspector from birdnetpi.system.system_control import SERVICES_CONFIG, SystemControlService from birdnetpi.system.system_utils import SystemUtils +from birdnetpi.utils.auth import require_admin from birdnetpi.web.core.container import Container from birdnetpi.web.models.services import ( ConfigReloadResponse, @@ -38,8 +39,10 @@ @router.get("/hardware/status", response_model=HardwareStatusResponse) +@require_admin @inject async def get_hardware_status( + request: Request, detection_query_service: Annotated[ DetectionQueryService, Depends(Provide[Container.detection_query_service]) ], @@ -106,8 +109,10 @@ async def get_hardware_status( @router.get("/services/status", response_model=ServicesStatusResponse) +@require_admin @inject async def get_services_status( + request: Request, system_control: Annotated[ SystemControlService, Depends(Provide[Container.system_control_service]) ], @@ -156,8 +161,10 @@ async def get_services_status( @router.post("/services/reload-config", response_model=ConfigReloadResponse) +@require_admin @inject async def reload_configuration( + request: Request, system_control: Annotated[ SystemControlService, Depends(Provide[Container.system_control_service]) ], @@ -180,8 +187,10 @@ async def reload_configuration( @router.get("/services/info", response_model=SystemInfo) +@require_admin @inject async def get_system_info( + request: Request, system_control: Annotated[ SystemControlService, Depends(Provide[Container.system_control_service]) ], @@ -203,9 +212,11 @@ async def get_system_info( @router.post("/services/reboot", response_model=SystemRebootResponse) +@require_admin @inject async def reboot_system( - request: SystemRebootRequest, + request: Request, + reboot_request: SystemRebootRequest, system_control: Annotated[ SystemControlService, Depends(Provide[Container.system_control_service]) ], @@ -215,7 +226,7 @@ async def reboot_system( Requires confirmation to prevent accidental reboots. Only available if the deployment supports it. """ - if not request.confirm: + if not reboot_request.confirm: return SystemRebootResponse( success=False, message="Reboot requires confirmation", @@ -250,8 +261,10 @@ async def reboot_system( @router.get("/services") +@require_admin @inject async def get_services_list( + request: Request, system_control: Annotated[ SystemControlService, Depends(Provide[Container.system_control_service]) ], @@ -285,11 +298,13 @@ async def get_services_list( @router.post("/services/{service_name}/{action}", response_model=ServiceActionResponse) +@require_admin @inject async def perform_service_action( + request: Request, service_name: Annotated[str, Path(description="Name of the service")], action: Annotated[str, Path(pattern="^(start|stop|restart)$", description="Action to perform")], - request: ServiceActionRequest, + action_request: ServiceActionRequest, system_control: Annotated[ SystemControlService, Depends(Provide[Container.system_control_service]) ], @@ -305,7 +320,7 @@ async def perform_service_action( service_config = next((s for s in service_configs if s.name == service_name), None) if service_config and service_config.critical and action in ["restart", "stop"]: - if not request.confirm: + if not action_request.confirm: return ServiceActionResponse( success=False, message=( diff --git a/src/birdnetpi/web/routers/update_api_routes.py b/src/birdnetpi/web/routers/update_api_routes.py index 09279c3d..639ce8bd 100644 --- a/src/birdnetpi/web/routers/update_api_routes.py +++ b/src/birdnetpi/web/routers/update_api_routes.py @@ -5,7 +5,7 @@ from typing import Annotated, Any from dependency_injector.wiring import Provide, inject -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Request from birdnetpi.config import BirdNETConfig from birdnetpi.config.manager import ConfigManager @@ -14,6 +14,7 @@ from birdnetpi.system.git_operations import GitOperationsService from birdnetpi.system.path_resolver import PathResolver from birdnetpi.system.system_utils import SystemUtils +from birdnetpi.utils.auth import require_admin from birdnetpi.utils.cache import Cache from birdnetpi.web.core.container import Container from birdnetpi.web.models.update import ( @@ -34,9 +35,11 @@ @router.post("/check") +@require_admin @inject async def check_for_updates( - request: UpdateCheckRequest, + request: Request, + update_request: UpdateCheckRequest, cache: Annotated[Cache, Depends(Provide[Container.cache_service])], ) -> UpdateStatusResponse: """Check for available updates. @@ -48,7 +51,7 @@ async def check_for_updates( # Queue the check request for the update daemon cache.set( "update:request", - {"action": "check", "force": request.force}, + {"action": "check", "force": update_request.force}, ttl=60, # Request expires after 1 minute ) @@ -69,8 +72,10 @@ async def check_for_updates( @router.get("/status") +@require_admin @inject async def get_update_status( + request: Request, cache: Annotated[Cache, Depends(Provide[Container.cache_service])], ) -> UpdateStatusResponse: """Get current update status from cache. @@ -93,9 +98,11 @@ async def get_update_status( @router.post("/apply") +@require_admin @inject async def apply_update( - request: UpdateApplyRequest, + request: Request, + update_request: UpdateApplyRequest, cache: Annotated[Cache, Depends(Provide[Container.cache_service])], ) -> UpdateActionResponse: """Apply a system update. @@ -119,14 +126,14 @@ async def apply_update( "update:request", { "action": "apply", - "version": request.version, - "dry_run": request.dry_run, + "version": update_request.version, + "dry_run": update_request.dry_run, }, ttl=300, # Request expires after 5 minutes ) # Construct status message (not SQL) - status_message = f"Update to version {request.version} has been queued" # nosemgrep + status_message = f"Update to version {update_request.version} has been queued" # nosemgrep return UpdateActionResponse( success=True, message=status_message, @@ -138,8 +145,10 @@ async def apply_update( @router.get("/result") +@require_admin @inject async def get_update_result( + request: Request, cache: Annotated[Cache, Depends(Provide[Container.cache_service])], ) -> dict[str, Any]: """Get the result of the last update operation. @@ -163,8 +172,10 @@ async def get_update_result( @router.delete("/cancel") +@require_admin @inject async def cancel_update( + request: Request, cache: Annotated[Cache, Depends(Provide[Container.cache_service])], ) -> UpdateActionResponse: """Cancel a pending update request. @@ -202,16 +213,19 @@ async def cancel_update( @router.post("/config/git") +@require_admin @inject async def update_git_config( - request: GitConfigRequest, + request: Request, + git_config: GitConfigRequest, config: Annotated[BirdNETConfig, Depends(Provide[Container.config])], path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], ) -> UpdateActionResponse: """Update git configuration for system updates. Args: - request: Git configuration settings + request: FastAPI request object (required by authentication decorator) + git_config: Git configuration settings config: Current BirdNET configuration path_resolver: Path resolver for configuration paths @@ -220,8 +234,8 @@ async def update_git_config( """ try: # Update git settings on the config object - config.updates.git_remote = request.git_remote - config.updates.git_branch = request.git_branch + config.updates.git_remote = git_config.git_remote + config.updates.git_branch = git_config.git_branch # Save configuration using ConfigManager config_manager = ConfigManager(path_resolver) @@ -229,13 +243,13 @@ async def update_git_config( logger.info( "Updated git configuration: remote=%s, branch=%s", - request.git_remote, - request.git_branch, + git_config.git_remote, + git_config.git_branch, ) return UpdateActionResponse( success=True, - message=f"Git configuration updated: {request.git_remote}/{request.git_branch}", + message=f"Git configuration updated: {git_config.git_remote}/{git_config.git_branch}", ) except ValueError as e: @@ -251,8 +265,10 @@ async def update_git_config( @router.get("/git/remotes") +@require_admin @inject async def list_git_remotes( + request: Request, path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], ) -> GitRemoteListResponse: """List all configured git remotes. @@ -285,15 +301,18 @@ async def list_git_remotes( @router.post("/git/remotes") +@require_admin @inject async def add_git_remote( - request: GitRemoteRequest, + request: Request, + remote_request: GitRemoteRequest, path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], ) -> UpdateActionResponse: """Add a new git remote. Args: - request: Remote name and URL + request: FastAPI request object (required by authentication decorator) + remote_request: Remote name and URL path_resolver: Path resolver for repository location Returns: @@ -309,11 +328,11 @@ async def add_git_remote( ) git_service = GitOperationsService(path_resolver) - git_service.add_remote(request.name, request.url) + git_service.add_remote(remote_request.name, remote_request.url) return UpdateActionResponse( success=True, - message=f"Git remote '{request.name}' added successfully", + message=f"Git remote '{remote_request.name}' added successfully", ) except subprocess.CalledProcessError as e: # Exit code 128 means not a git repository @@ -336,17 +355,20 @@ async def add_git_remote( @router.put("/git/remotes/{remote_name}") +@require_admin @inject async def update_git_remote( + request: Request, remote_name: str, - request: GitRemoteRequest, + remote_request: GitRemoteRequest, path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], ) -> UpdateActionResponse: """Update an existing git remote URL. Args: + request: FastAPI request object (required by authentication decorator) remote_name: Name of remote to update - request: New remote configuration + remote_request: New remote configuration path_resolver: Path resolver for repository location Returns: @@ -364,7 +386,7 @@ async def update_git_remote( git_service = GitOperationsService(path_resolver) # If name is changing, delete old and add new - if remote_name != request.name: + if remote_name != remote_request.name: # Can't rename origin if remote_name == "origin": return UpdateActionResponse( @@ -372,14 +394,14 @@ async def update_git_remote( error="Cannot rename 'origin' remote. Edit URL only.", ) git_service.delete_remote(remote_name) - git_service.add_remote(request.name, request.url) + git_service.add_remote(remote_request.name, remote_request.url) else: # Just update URL - git_service.update_remote(request.name, request.url) + git_service.update_remote(remote_request.name, remote_request.url) return UpdateActionResponse( success=True, - message=f"Git remote '{request.name}' updated successfully", + message=f"Git remote '{remote_request.name}' updated successfully", ) except ValueError as e: logger.warning("Invalid git remote update: %s", e) @@ -390,14 +412,17 @@ async def update_git_remote( @router.delete("/git/remotes/{remote_name}") +@require_admin @inject async def delete_git_remote( + request: Request, remote_name: str, path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], ) -> UpdateActionResponse: """Delete a git remote. Args: + request: FastAPI request object (required by authentication decorator) remote_name: Name of remote to delete path_resolver: Path resolver for repository location @@ -432,14 +457,17 @@ async def delete_git_remote( @router.get("/git/branches/{remote_name}") +@require_admin @inject async def list_git_branches( + request: Request, remote_name: str, path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], ) -> GitBranchListResponse: """List branches and tags for a git remote. Args: + request: FastAPI request object (required by authentication decorator) remote_name: Name of remote to query path_resolver: Path resolver for repository location @@ -470,8 +498,10 @@ async def list_git_branches( @router.get("/region-pack/status") +@require_admin @inject async def get_region_pack_status( + request: Request, path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], config: Annotated[BirdNETConfig, Depends(Provide[Container.config])], ) -> dict[str, Any]: @@ -485,8 +515,10 @@ async def get_region_pack_status( @router.get("/region-pack/available") +@require_admin @inject async def list_available_region_packs( + request: Request, path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], config: Annotated[BirdNETConfig, Depends(Provide[Container.config])], ) -> dict[str, Any]: @@ -504,8 +536,10 @@ async def list_available_region_packs( @router.post("/region-pack/download") +@require_admin @inject async def download_region_pack( + request: Request, path_resolver: Annotated[PathResolver, Depends(Provide[Container.path_resolver])], config: Annotated[BirdNETConfig, Depends(Provide[Container.config])], cache: Annotated[Cache, Depends(Provide[Container.cache_service])], diff --git a/src/birdnetpi/web/routers/update_view_routes.py b/src/birdnetpi/web/routers/update_view_routes.py index e10a45e4..f3015511 100644 --- a/src/birdnetpi/web/routers/update_view_routes.py +++ b/src/birdnetpi/web/routers/update_view_routes.py @@ -12,6 +12,7 @@ from birdnetpi.i18n.translation_manager import TranslationManager from birdnetpi.system.status import SystemInspector from birdnetpi.system.system_utils import SystemUtils +from birdnetpi.utils.auth import require_admin from birdnetpi.utils.cache import Cache from birdnetpi.utils.language import get_user_language from birdnetpi.web.core.container import Container @@ -24,6 +25,7 @@ @router.get("/", response_class=HTMLResponse) +@require_admin @inject async def update_page( request: Request, diff --git a/src/birdnetpi/web/static/css/auth.css b/src/birdnetpi/web/static/css/auth.css new file mode 100644 index 00000000..b8489cd7 --- /dev/null +++ b/src/birdnetpi/web/static/css/auth.css @@ -0,0 +1,168 @@ +/* Authentication pages styling */ + +.auth-page { + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + min-height: 100vh; + display: flex; + align-items: center; + justify-content: center; + padding: 2rem; +} + +.auth-container { + width: 100%; + max-width: 420px; + margin: 2rem auto; + padding: 0 1rem; +} + +.auth-card { + background: white; + border-radius: 12px; + box-shadow: 0 10px 40px rgba(0, 0, 0, 0.15); + padding: 2.5rem; +} + +.auth-header { + text-align: center; + margin-bottom: 2rem; +} + +.auth-header h1, +.auth-header h2 { + font-size: 1.875rem; + font-weight: 700; + color: #1a202c; + margin: 0 0 0.5rem 0; +} + +.auth-subtitle { + color: #718096; + font-size: 0.938rem; + margin: 0; +} + +.auth-error { + background-color: #fed7d7; + border: 1px solid #fc8181; + border-radius: 6px; + color: #742a2a; + padding: 0.75rem 1rem; + margin-bottom: 1.5rem; + font-size: 0.875rem; +} + +.auth-error strong { + font-weight: 600; +} + +.auth-form { + display: flex; + flex-direction: column; + gap: 1.25rem; +} + +.form-group { + display: flex; + flex-direction: column; + gap: 0.375rem; +} + +.form-group label { + font-weight: 500; + font-size: 0.875rem; + color: #2d3748; +} + +.form-group input { + padding: 0.625rem 0.75rem; + border: 1px solid #cbd5e0; + border-radius: 6px; + font-size: 0.938rem; + transition: + border-color 0.2s, + box-shadow 0.2s; +} + +.form-group input:focus { + outline: none; + border-color: #667eea; + box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1); +} + +.form-help { + color: #718096; + font-size: 0.813rem; + margin: 0; +} + +.form-error { + color: #e53e3e; + font-size: 0.813rem; + font-weight: 500; + margin: 0; +} + +.auth-button { + background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); + color: white; + padding: 0.75rem 1.5rem; + border: none; + border-radius: 6px; + font-size: 1rem; + font-weight: 600; + cursor: pointer; + transition: + transform 0.2s, + box-shadow 0.2s; + margin-top: 0.5rem; +} + +.auth-button:hover { + transform: translateY(-2px); + box-shadow: 0 4px 12px rgba(102, 126, 234, 0.4); +} + +.auth-button:active { + transform: translateY(0); +} + +.auth-button:disabled { + opacity: 0.6; + cursor: not-allowed; +} + +.auth-footer { + margin-top: 1.5rem; + padding-top: 1.5rem; + border-top: 1px solid #e2e8f0; + text-align: center; +} + +.auth-link { + color: #667eea; + text-decoration: none; + font-size: 0.938rem; + font-weight: 500; + transition: color 0.2s; +} + +.auth-link:hover { + color: #764ba2; + text-decoration: underline; +} + +/* Responsive adjustments */ +@media (max-width: 640px) { + .auth-page { + padding: 1rem; + } + + .auth-card { + padding: 2rem 1.5rem; + } + + .auth-header h1 { + font-size: 1.5rem; + } +} diff --git a/src/birdnetpi/web/static/css/style.css b/src/birdnetpi/web/static/css/style.css index f4189414..3387b207 100644 --- a/src/birdnetpi/web/static/css/style.css +++ b/src/birdnetpi/web/static/css/style.css @@ -1976,6 +1976,13 @@ h1 { background: #f0f0f0; } +/* Menu divider */ +.admin-menu .menu-divider { + border: none; + border-top: 1px solid #e0e0e0; + margin: 5px 0; +} + /* Dark mode support */ @media (prefers-color-scheme: dark) { .admin-menu { @@ -1994,6 +2001,10 @@ h1 { .admin-menu a.active { background: #3a3a3a; } + + .admin-menu .menu-divider { + border-top-color: #444; + } } /* ============================================ diff --git a/src/birdnetpi/web/templates/admin/login.html.j2 b/src/birdnetpi/web/templates/admin/login.html.j2 new file mode 100644 index 00000000..24d85bbb --- /dev/null +++ b/src/birdnetpi/web/templates/admin/login.html.j2 @@ -0,0 +1,47 @@ +{% extends "base.html.j2" %} + +{% block content %} +
+ + + {% if error %} + + {% endif %} + +
+
+
+ + +
+ +
+ + +
+ + +
+
+
+{% endblock %} diff --git a/src/birdnetpi/web/templates/admin/setup.html.j2 b/src/birdnetpi/web/templates/admin/setup.html.j2 new file mode 100644 index 00000000..34165ebd --- /dev/null +++ b/src/birdnetpi/web/templates/admin/setup.html.j2 @@ -0,0 +1,102 @@ + + + + + + BirdNET-Pi Setup + + + + +
+
+
+

Welcome to BirdNET-Pi

+

Let's set up your administrator account

+
+ +
+
+ + + + This will be your administrator username + +
+ +
+ + + + Minimum 8 characters + +
+ +
+ + + +
+ + +
+
+
+ + + + diff --git a/src/birdnetpi/web/templates/components/navigation.html.j2 b/src/birdnetpi/web/templates/components/navigation.html.j2 index 9c679dc5..0af625a2 100644 --- a/src/birdnetpi/web/templates/components/navigation.html.j2 +++ b/src/birdnetpi/web/templates/components/navigation.html.j2 @@ -3,7 +3,7 @@ diff --git a/tests/birdnetpi/database/test_core.py b/tests/birdnetpi/database/test_core.py index 28f12c8f..66ad838c 100644 --- a/tests/birdnetpi/database/test_core.py +++ b/tests/birdnetpi/database/test_core.py @@ -32,6 +32,7 @@ async def core_database_service(tmp_path): @pytest.mark.no_leaks @pytest.mark.asyncio +@pytest.mark.ci_issue # Flaky in CI due to async context manager mocking timing issues @pytest.mark.parametrize( "operation,should_fail,exception", [ diff --git a/tests/birdnetpi/web/routers/test_auth_routes.py b/tests/birdnetpi/web/routers/test_auth_routes.py new file mode 100644 index 00000000..5bf81bc0 --- /dev/null +++ b/tests/birdnetpi/web/routers/test_auth_routes.py @@ -0,0 +1,80 @@ +"""Tests for authentication routes including SQLAdmin redirects.""" + +from unittest.mock import MagicMock + +import pytest +from dependency_injector import providers +from fastapi.testclient import TestClient + +from birdnetpi.utils.auth import AuthService +from birdnetpi.web.core.container import Container +from birdnetpi.web.core.factory import create_app + + +@pytest.fixture +async def client_with_admin(app_with_temp_data): + """Create a test client with admin authentication mocked. + + This fixture mocks the auth_service to report that an admin user exists, + which prevents the SetupRedirectMiddleware from redirecting all requests + to /admin/setup. + + Uses app_with_temp_data fixture to ensure proper path isolation, then + creates a new app with mocked auth_service that reports admin exists. + """ + # Mock auth_service to always return True for admin_exists() + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + + # Override Container's auth_service before app creation + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) + + # Create new app with mocked auth (path_resolver and config already overridden) + app = create_app() + + yield TestClient(app) + + # Reset override + Container.auth_service.reset_override() + + +class TestSQLAdminRedirects: + """Test that SQLAdmin login/logout routes redirect to BirdNET-Pi authentication.""" + + def test_sqladmin_login_get_redirects_to_birdnetpi_login(self, client_with_admin): + """Should redirect GET /admin/database/login to BirdNET-Pi login with next parameter.""" + response = client_with_admin.get("/admin/database/login", follow_redirects=False) + + assert response.status_code == 303 + assert response.headers["location"] == "/admin/login?next=/admin/database" + + def test_sqladmin_login_post_redirects_to_birdnetpi_login(self, client_with_admin): + """Should redirect POST /admin/database/login to BirdNET-Pi login with next parameter.""" + response = client_with_admin.post( + "/admin/database/login", + data={"username": "admin", "password": "password"}, + follow_redirects=False, + ) + + assert response.status_code == 303 + assert response.headers["location"] == "/admin/login?next=/admin/database" + + def test_sqladmin_logout_redirects_to_birdnetpi_logout(self, client_with_admin): + """Should redirect GET /admin/database/logout to BirdNET-Pi logout.""" + response = client_with_admin.get("/admin/database/logout", follow_redirects=False) + + assert response.status_code == 303 + assert response.headers["location"] == "/admin/logout" + + def test_redirect_routes_take_precedence_over_sqladmin(self, app_with_temp_data): + """Should verify auth redirect routes are registered and take precedence.""" + # Get all routes from the app + routes = [] + for route in app_with_temp_data.routes: + if hasattr(route, "path") and hasattr(route, "methods"): + routes.append((route.path, route.methods)) + + # Check that our redirect routes exist + assert ("/admin/database/login", {"GET"}) in routes + assert ("/admin/database/login", {"POST"}) in routes + assert ("/admin/database/logout", {"GET"}) in routes diff --git a/tests/birdnetpi/web/routers/test_logs_api_routes.py b/tests/birdnetpi/web/routers/test_logs_api_routes.py index 4953efa2..28be591d 100644 --- a/tests/birdnetpi/web/routers/test_logs_api_routes.py +++ b/tests/birdnetpi/web/routers/test_logs_api_routes.py @@ -1,15 +1,18 @@ """Tests for log API routes.""" -from datetime import datetime -from unittest.mock import AsyncMock +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock import pytest +import redis.asyncio from dependency_injector import providers from fastapi.testclient import TestClient from birdnetpi.config import ConfigManager from birdnetpi.database.core import CoreDatabaseService from birdnetpi.system.log_reader import LogReaderService +from birdnetpi.utils.auth import AdminUser, AuthService, pwd_context +from birdnetpi.utils.cache import Cache from birdnetpi.web.core.container import Container from birdnetpi.web.core.factory import create_app from birdnetpi.web.models.logs import LOG_LEVELS @@ -28,10 +31,10 @@ def mock_log_reader(): @pytest.fixture async def app_with_mock_log_reader(path_resolver, mock_log_reader): - """Create FastAPI app with mocked log reader service. + """Create FastAPI app with mocked log reader service and authentication support. This fixture overrides the log reader BEFORE creating the app, - ensuring the mock is properly wired. + ensuring the mock is properly wired. Also sets up authentication. """ # Override Container providers BEFORE creating app Container.path_resolver.override(providers.Singleton(lambda: path_resolver)) @@ -47,6 +50,48 @@ async def app_with_mock_log_reader(path_resolver, mock_log_reader): await temp_db_service.initialize() Container.core_database.override(providers.Singleton(lambda: temp_db_service)) + # Mock cache service + mock_cache = MagicMock(spec=Cache) + mock_cache.configure_mock( + **{"get.return_value": None, "set.return_value": True, "ping.return_value": True} + ) + Container.cache_service.override(providers.Singleton(lambda: mock_cache)) + + # Mock redis client with in-memory storage for sessions + mock_redis = AsyncMock(spec=redis.asyncio.Redis) + redis_storage = {} + + async def mock_set(key, value, ex=None): + redis_storage[key] = value + return True + + async def mock_get(key): + return redis_storage.get(key) + + async def mock_delete(key): + redis_storage.pop(key, None) + return True + + mock_redis.set = AsyncMock(spec=object, side_effect=mock_set) + mock_redis.get = AsyncMock(spec=object, side_effect=mock_get) + mock_redis.delete = AsyncMock(spec=object, side_effect=mock_delete) + mock_redis.close = AsyncMock(spec=object) + Container.redis_client.override(providers.Singleton(lambda: mock_redis)) + + # Mock auth service + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + mock_admin = AdminUser( + username="admin", + password_hash=pwd_context.hash("testpassword"), + created_at=datetime.now(UTC), + ) + mock_auth_service.load_admin_user.return_value = mock_admin + mock_auth_service.verify_password.side_effect = lambda plain, hashed: pwd_context.verify( + plain, hashed + ) + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) + # Override log reader with our mock Container.log_reader.override(providers.Object(mock_log_reader)) @@ -65,13 +110,16 @@ async def app_with_mock_log_reader(path_resolver, mock_log_reader): Container.database_path.reset_override() Container.config.reset_override() Container.core_database.reset_override() + Container.cache_service.reset_override() + Container.redis_client.reset_override() + Container.auth_service.reset_override() Container.log_reader.reset_override() class TestLogsAPIRoutes: """Test log API routes.""" - def test_get_logs(self, app_with_mock_log_reader, mock_log_reader): + def test_get_logs(self, app_with_mock_log_reader, authenticate_sync_client, mock_log_reader): """Should return historical logs with correct format and metadata.""" # Configure the mock mock_log_reader.get_logs.return_value = [ @@ -96,6 +144,7 @@ def test_get_logs(self, app_with_mock_log_reader, mock_log_reader): try: with TestClient(app_with_mock_log_reader) as client: + authenticate_sync_client(client) response = client.get("/api/logs") finally: Container.log_reader.reset_override() @@ -109,7 +158,9 @@ def test_get_logs(self, app_with_mock_log_reader, mock_log_reader): assert data["total"] == 2 assert "levels" in data - def test_get_logs_with_time_filter(self, app_with_mock_log_reader, mock_log_reader): + def test_get_logs_with_time_filter( + self, app_with_mock_log_reader, authenticate_sync_client, mock_log_reader + ): """Should apply time filters when fetching logs.""" # Configure the mock mock_log_reader.get_logs.return_value = [ @@ -123,6 +174,7 @@ def test_get_logs_with_time_filter(self, app_with_mock_log_reader, mock_log_read ] with TestClient(app_with_mock_log_reader) as client: + authenticate_sync_client(client) response = client.get( "/api/logs", params={ @@ -143,12 +195,15 @@ def test_get_logs_with_time_filter(self, app_with_mock_log_reader, mock_log_read assert call_args["level"] is None # No level filtering assert call_args["keyword"] is None # No keyword filtering - def test_get_logs_error_handling(self, app_with_mock_log_reader, mock_log_reader): + def test_get_logs_error_handling( + self, app_with_mock_log_reader, authenticate_sync_client, mock_log_reader + ): """Should handle errors gracefully and return empty list with error message.""" # Configure the mock to raise an error mock_log_reader.get_logs.side_effect = Exception("Database error") with TestClient(app_with_mock_log_reader) as client: + authenticate_sync_client(client) response = client.get("/api/logs") assert response.status_code == 200 @@ -158,7 +213,7 @@ def test_get_logs_error_handling(self, app_with_mock_log_reader, mock_log_reader assert "error" in data assert "Database error" in data["error"] - def test_stream_logs(self, app_with_mock_log_reader, mock_log_reader): + def test_stream_logs(self, app_with_mock_log_reader, authenticate_sync_client, mock_log_reader): """Should stream logs via SSE with correct event format.""" # Create an async generator for streaming @@ -179,6 +234,7 @@ async def mock_stream(): mock_log_reader.stream_logs.return_value = mock_stream() with TestClient(app_with_mock_log_reader) as client: + authenticate_sync_client(client) # SSE endpoints return a streaming response # TestClient handles streaming automatically response = client.get("/api/logs/stream") @@ -200,7 +256,9 @@ async def mock_stream(): assert len(events) >= 1 assert "connected" in events[0] - def test_stream_logs_no_filters(self, app_with_mock_log_reader, mock_log_reader): + def test_stream_logs_no_filters( + self, app_with_mock_log_reader, authenticate_sync_client, mock_log_reader + ): """Should stream all logs without server-side filtering.""" async def mock_stream(): @@ -214,6 +272,7 @@ async def mock_stream(): mock_log_reader.stream_logs.return_value = mock_stream() with TestClient(app_with_mock_log_reader) as client: + authenticate_sync_client(client) # No parameters passed - filtering is client-side response = client.get("/api/logs/stream") assert response.status_code == 200 @@ -225,9 +284,10 @@ async def mock_stream(): assert call_args["level"] is None assert call_args["keyword"] is None - def test_get_log_levels(self, app_with_mock_log_reader): + def test_get_log_levels(self, app_with_mock_log_reader, authenticate_sync_client): """Should return all log levels with correct structure and colors.""" with TestClient(app_with_mock_log_reader) as client: + authenticate_sync_client(client) response = client.get("/api/logs/levels") assert response.status_code == 200 @@ -246,7 +306,9 @@ def test_get_log_levels(self, app_with_mock_log_reader): assert debug_level["value"] == 10 assert debug_level["color"] == "#6c757d" - def test_log_entry_parsing_edge_cases(self, app_with_mock_log_reader, mock_log_reader): + def test_log_entry_parsing_edge_cases( + self, app_with_mock_log_reader, authenticate_sync_client, mock_log_reader + ): """Should handle malformed log entries gracefully.""" # Return logs with various edge cases mock_log_reader.get_logs.return_value = [ @@ -266,6 +328,7 @@ def test_log_entry_parsing_edge_cases(self, app_with_mock_log_reader, mock_log_r ] with TestClient(app_with_mock_log_reader) as client: + authenticate_sync_client(client) response = client.get("/api/logs") assert response.status_code == 200 @@ -275,7 +338,9 @@ def test_log_entry_parsing_edge_cases(self, app_with_mock_log_reader, mock_log_r assert len(data["logs"]) >= 1 assert data["logs"][0]["message"] == "Normal log" - def test_limit_parameter(self, app_with_mock_log_reader, mock_log_reader): + def test_limit_parameter( + self, app_with_mock_log_reader, authenticate_sync_client, mock_log_reader + ): """Should respect limit parameter.""" # Create 20 mock logs mock_logs = [ @@ -291,6 +356,7 @@ def test_limit_parameter(self, app_with_mock_log_reader, mock_log_reader): mock_log_reader.get_logs.return_value = mock_logs[:5] # Return only 5 logs with TestClient(app_with_mock_log_reader) as client: + authenticate_sync_client(client) response = client.get("/api/logs", params={"limit": 5}) assert response.status_code == 200 diff --git a/tests/birdnetpi/web/routers/test_services_view_routes.py b/tests/birdnetpi/web/routers/test_services_view_routes.py index 83a6b8a8..9a2f2e21 100644 --- a/tests/birdnetpi/web/routers/test_services_view_routes.py +++ b/tests/birdnetpi/web/routers/test_services_view_routes.py @@ -14,7 +14,9 @@ class TestServicesViewRoutes: """Test class for services view endpoints.""" @pytest.mark.asyncio - async def test_services_page_renders_successfully(self, app_with_temp_data, path_resolver): + async def test_services_page_renders_successfully( + self, app_with_temp_data, path_resolver, authenticate_sync_client + ): """Should render services page with correct context.""" # Mock the system control service mock_system_control = MagicMock(spec=SystemControlService) @@ -52,6 +54,7 @@ async def test_services_page_renders_successfully(self, app_with_temp_data, path # Create test client with TestClient(app_with_temp_data) as client: + authenticate_sync_client(client) response = client.get("/admin/services") # Should render successfully @@ -59,7 +62,9 @@ async def test_services_page_renders_successfully(self, app_with_temp_data, path assert b"Services" in response.content or b"services" in response.content @pytest.mark.asyncio - async def test_services_page_handles_service_error(self, app_with_temp_data, path_resolver): + async def test_services_page_handles_service_error( + self, app_with_temp_data, path_resolver, authenticate_sync_client + ): """Should handle errors when getting service status.""" # Mock the system control service to raise an error mock_system_control = MagicMock(spec=SystemControlService) @@ -76,13 +81,16 @@ async def test_services_page_handles_service_error(self, app_with_temp_data, pat # Create test client with TestClient(app_with_temp_data) as client: + authenticate_sync_client(client) response = client.get("/admin/services") # Should still render successfully with empty services assert response.status_code == 200 @pytest.mark.asyncio - async def test_services_page_handles_system_info_error(self, app_with_temp_data, path_resolver): + async def test_services_page_handles_system_info_error( + self, app_with_temp_data, path_resolver, authenticate_sync_client + ): """Should handle errors when getting system info.""" # Mock the system control service mock_system_control = MagicMock(spec=SystemControlService) @@ -94,13 +102,16 @@ async def test_services_page_handles_system_info_error(self, app_with_temp_data, # Create test client with TestClient(app_with_temp_data) as client: + authenticate_sync_client(client) response = client.get("/admin/services") # Should still render successfully with default system info assert response.status_code == 200 @pytest.mark.asyncio - async def test_services_page_formats_uptime(self, app_with_temp_data, path_resolver): + async def test_services_page_formats_uptime( + self, app_with_temp_data, path_resolver, authenticate_sync_client + ): """Should format service uptime correctly.""" # Mock the system control service with various uptime values mock_system_control = MagicMock(spec=SystemControlService) @@ -146,6 +157,7 @@ async def test_services_page_formats_uptime(self, app_with_temp_data, path_resol # Create test client with TestClient(app_with_temp_data) as client: + authenticate_sync_client(client) response = client.get("/admin/services") # Should render successfully @@ -153,7 +165,7 @@ async def test_services_page_formats_uptime(self, app_with_temp_data, path_resol @pytest.mark.asyncio async def test_services_page_identifies_critical_services( - self, app_with_temp_data, path_resolver + self, app_with_temp_data, path_resolver, authenticate_sync_client ): """Should correctly identify critical services.""" # Mock the system control service @@ -191,13 +203,16 @@ async def test_services_page_identifies_critical_services( # Create test client with TestClient(app_with_temp_data) as client: + authenticate_sync_client(client) response = client.get("/admin/services") # Should render successfully assert response.status_code == 200 @pytest.mark.asyncio - async def test_services_page_service_status_variations(self, app_with_temp_data, path_resolver): + async def test_services_page_service_status_variations( + self, app_with_temp_data, path_resolver, authenticate_sync_client + ): """Should handle various service status values.""" # Mock the system control service with different statuses mock_system_control = MagicMock(spec=SystemControlService) @@ -261,6 +276,7 @@ async def test_services_page_service_status_variations(self, app_with_temp_data, # Create test client with TestClient(app_with_temp_data) as client: + authenticate_sync_client(client) response = client.get("/admin/services") # Should render successfully @@ -295,7 +311,7 @@ async def test_services_page_deployment_type_variations( @pytest.mark.asyncio async def test_services_page_with_systemutils_deployment( - self, app_with_temp_data, path_resolver + self, app_with_temp_data, path_resolver, authenticate_sync_client ): """Should use SystemUtils for deployment type when needed.""" # Mock the system control service @@ -319,6 +335,7 @@ async def test_services_page_with_systemutils_deployment( # Create test client with TestClient(app_with_temp_data) as client: + authenticate_sync_client(client) response = client.get("/admin/services") # Should render successfully diff --git a/tests/birdnetpi/web/routers/test_settings_api_routes.py b/tests/birdnetpi/web/routers/test_settings_api_routes.py index 42519489..ffff536d 100644 --- a/tests/birdnetpi/web/routers/test_settings_api_routes.py +++ b/tests/birdnetpi/web/routers/test_settings_api_routes.py @@ -4,17 +4,34 @@ from dependency_injector import providers from fastapi import FastAPI from fastapi.testclient import TestClient +from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import HTTPConnection from birdnetpi.web.core.container import Container from birdnetpi.web.routers.settings_api_routes import router +class TestAuthBackend(AuthenticationBackend): + """Test authentication backend that always authenticates as admin.""" + + async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, SimpleUser] | None: + """Return authenticated admin user for tests.""" + return AuthCredentials(["authenticated"]), SimpleUser("test-admin") + + @pytest.fixture def client(tmp_path, path_resolver): """Create test client with admin API routes and mocked dependencies.""" # Create the app app = FastAPI() + # Add authentication middleware with test backend that always authenticates + app.add_middleware( + AuthenticationMiddleware, + backend=TestAuthBackend(), + ) + # Create the real container container = Container() @@ -33,7 +50,7 @@ def client(tmp_path, path_resolver): # Include the router with the same prefix as in factory app.include_router(router, prefix="/api") - # Create and return test client + # Create test client client = TestClient(app) # Store the path resolver for access in tests diff --git a/tests/birdnetpi/web/routers/test_settings_routes.py b/tests/birdnetpi/web/routers/test_settings_routes.py index ea141c06..d41e482c 100644 --- a/tests/birdnetpi/web/routers/test_settings_routes.py +++ b/tests/birdnetpi/web/routers/test_settings_routes.py @@ -2,6 +2,7 @@ import shutil import tempfile +from datetime import datetime from pathlib import Path from unittest.mock import MagicMock, Mock, patch @@ -13,6 +14,7 @@ from birdnetpi.audio.devices import AudioDevice, AudioDeviceService from birdnetpi.config import BirdNETConfig, ConfigManager +from birdnetpi.utils.auth import AdminUser, AuthService, pwd_context from birdnetpi.web.core.container import Container from birdnetpi.web.core.factory import create_app @@ -85,11 +87,26 @@ def app_with_settings_routes(path_resolver, repo_root, test_config, mock_audio_d mock_config_manager.save.return_value = None mock_audio_service = MagicMock(spec=AudioDeviceService) mock_audio_service.discover_input_devices.return_value = mock_audio_devices + + # Mock AuthService to enable authentication in tests + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + mock_admin = AdminUser( + username="admin", + password_hash=pwd_context.hash("testpassword"), + created_at=datetime.now(), + ) + mock_auth_service.load_admin_user.return_value = mock_admin + mock_auth_service.verify_password.side_effect = lambda plain, hashed: pwd_context.verify( + plain, hashed + ) + path_resolver.get_birdnetpi_config_path = lambda: config_dir / "birdnetpi.yaml" path_resolver.get_templates_dir = lambda: repo_root / "src" / "birdnetpi" / "web" / "templates" path_resolver.get_static_dir = lambda: repo_root / "src" / "birdnetpi" / "web" / "static" path_resolver.get_models_dir = lambda: repo_root / "models" Container.path_resolver.override(providers.Singleton(lambda: path_resolver)) + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) templates_dir = repo_root / "src" / "birdnetpi" / "web" / "templates" Container.templates.override( providers.Singleton(lambda: Jinja2Templates(directory=str(templates_dir))) @@ -113,15 +130,17 @@ def app_with_settings_routes(path_resolver, repo_root, test_config, mock_audio_d app = create_app() yield (app, mock_config_manager, mock_audio_service) Container.path_resolver.reset_override() + Container.auth_service.reset_override() Container.templates.reset_override() shutil.rmtree(temp_dir, ignore_errors=True) @pytest.fixture -def client_with_mocks(app_with_settings_routes): - """Create test client with mocked dependencies.""" +def client_with_mocks(app_with_settings_routes, authenticate_sync_client): + """Create authenticated test client with mocked dependencies.""" app, config_manager, audio_service = app_with_settings_routes with TestClient(app) as test_client: + authenticate_sync_client(test_client) yield (test_client, config_manager, audio_service) @@ -184,12 +203,11 @@ def test_settings_page_includes_hidden_inputs(self, client_with_mocks): assert 'name="site_name"' in response.text assert 'name="model"' in response.text - def test_settings_page_handles_no_audio_devices(self, app_with_settings_routes): + def test_settings_page_handles_no_audio_devices(self, client_with_mocks): """Should handle case when no audio devices are available.""" - app, _config_manager, audio_service = app_with_settings_routes + client, _config_manager, audio_service = client_with_mocks audio_service.discover_input_devices.return_value = [] - with TestClient(app) as client: - response = client.get("/admin/settings") + response = client.get("/admin/settings") assert response.status_code == 200 assert "System Default" in response.text diff --git a/tests/birdnetpi/web/routers/test_sqladmin_view_routes.py b/tests/birdnetpi/web/routers/test_sqladmin_view_routes.py index e77669a0..27768b15 100644 --- a/tests/birdnetpi/web/routers/test_sqladmin_view_routes.py +++ b/tests/birdnetpi/web/routers/test_sqladmin_view_routes.py @@ -66,9 +66,15 @@ def test_setup_sqladmin_creates_admin_instance(self, mock_admin_class, mock_cont mock_admin_instance = MagicMock(spec=Admin) mock_admin_class.return_value = mock_admin_instance result = setup_sqladmin(app) - mock_admin_class.assert_called_once_with( - app, mock_async_engine, base_url="/admin/database", title="BirdNET-Pi Database Admin" - ) + # Check that Admin was called with the expected parameters + # Note: authentication_backend is also passed, but we check for the core params + mock_admin_class.assert_called_once() + call_args, call_kwargs = mock_admin_class.call_args + assert call_args[0] == app + assert call_args[1] == mock_async_engine + assert call_kwargs["base_url"] == "/admin/database" + assert call_kwargs["title"] == "BirdNET-Pi Database Admin" + assert "authentication_backend" in call_kwargs assert mock_admin_instance.add_view.call_count == 3 call_args = [call.args[0] for call in mock_admin_instance.add_view.call_args_list] assert DetectionAdmin in call_args diff --git a/tests/birdnetpi/web/routers/test_system_api_routes.py b/tests/birdnetpi/web/routers/test_system_api_routes.py index a406b21f..1fa78c77 100644 --- a/tests/birdnetpi/web/routers/test_system_api_routes.py +++ b/tests/birdnetpi/web/routers/test_system_api_routes.py @@ -1,18 +1,30 @@ """Tests for system API routes that handle hardware monitoring and system status.""" -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest +import redis.asyncio from dependency_injector import providers from fastapi.testclient import TestClient +from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser +from starlette.requests import HTTPConnection from birdnetpi.config import ConfigManager from birdnetpi.database.core import CoreDatabaseService from birdnetpi.detections.queries import DetectionQueryService +from birdnetpi.utils.auth import AuthService from birdnetpi.web.core.container import Container from birdnetpi.web.core.factory import create_app +class TestAuthBackend(AuthenticationBackend): + """Test authentication backend that always authenticates as admin.""" + + async def authenticate(self, conn: HTTPConnection) -> tuple[AuthCredentials, SimpleUser] | None: + """Return authenticated admin user for tests.""" + return AuthCredentials(["authenticated"]), SimpleUser("test-admin") + + @pytest.fixture def mock_detection_query_service(): """Create a mock DetectionQueryService. @@ -41,7 +53,20 @@ async def app_with_system_router(path_resolver, mock_detection_query_service): await temp_db_service.initialize() Container.core_database.override(providers.Singleton(lambda: temp_db_service)) Container.detection_query_service.override(providers.Object(mock_detection_query_service)) - app = create_app() + + # Mock Redis client with spec + mock_redis = MagicMock(spec=redis.asyncio.Redis) + Container.redis_client.override(providers.Singleton(lambda: mock_redis)) + + # Mock auth_service to always return True for admin_exists() + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) + + # Patch SessionAuthBackend to use TestAuthBackend for authentication + with patch("birdnetpi.web.core.factory.SessionAuthBackend", TestAuthBackend): + app = create_app() + app._test_db_service = temp_db_service # type: ignore[attr-defined] app._test_mock_query_service = mock_detection_query_service # type: ignore[attr-defined] yield app @@ -52,6 +77,8 @@ async def app_with_system_router(path_resolver, mock_detection_query_service): Container.config.reset_override() Container.core_database.reset_override() Container.detection_query_service.reset_override() + Container.redis_client.reset_override() + Container.auth_service.reset_override() @pytest.fixture diff --git a/tests/birdnetpi/web/routers/test_system_services_api_routes.py b/tests/birdnetpi/web/routers/test_system_services_api_routes.py index b5a6d35d..0a9d3537 100644 --- a/tests/birdnetpi/web/routers/test_system_services_api_routes.py +++ b/tests/birdnetpi/web/routers/test_system_services_api_routes.py @@ -1,15 +1,20 @@ """Tests for services API routes.""" -from unittest.mock import MagicMock, patch +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock, patch import pytest +import redis.asyncio from dependency_injector import providers -from fastapi import FastAPI from fastapi.testclient import TestClient +from birdnetpi.config import ConfigManager +from birdnetpi.database.core import CoreDatabaseService from birdnetpi.system.system_control import SystemControlService +from birdnetpi.utils.auth import AdminUser, AuthService, pwd_context +from birdnetpi.utils.cache import Cache from birdnetpi.web.core.container import Container -from birdnetpi.web.routers.system_api_routes import router +from birdnetpi.web.core.factory import create_app @pytest.fixture @@ -19,29 +24,103 @@ def mock_system_control(): @pytest.fixture -def client(path_resolver, mock_system_control): +async def client(path_resolver, mock_system_control, authenticate_sync_client): """Create test client with services API routes. Mocks deployment environment to consistently return "docker" so tests use the docker service configuration (where "fastapi" is a critical service). This prevents test failures in CI where systemd detection would return "sbc". + + Uses app_with_temp_data infrastructure for proper authentication. """ # Mock deployment environment to return "docker" consistently with patch( "birdnetpi.web.routers.system_api_routes.SystemUtils.get_deployment_environment", return_value="docker", ): - app = FastAPI() - container = Container() - # IMPORTANT: Override path_resolver BEFORE any other providers to prevent permission errors - container.path_resolver.override(providers.Singleton(lambda: path_resolver)) - container.database_path.override( + # Override Container class-level providers BEFORE app creation + # This is critical because create_app() uses the global Container singleton + Container.path_resolver.override(providers.Singleton(lambda: path_resolver)) + Container.database_path.override( providers.Factory(lambda: path_resolver.get_database_path()) ) - container.system_control_service.override(providers.Object(mock_system_control)) - container.wire(modules=["birdnetpi.web.routers.system_api_routes"]) - app.include_router(router, prefix="/api") - yield TestClient(app) + + # Create config + manager = ConfigManager(path_resolver) + test_config = manager.load() + Container.config.override(providers.Singleton(lambda: test_config)) + + # Create a test database service with the temp path + temp_db_service = CoreDatabaseService(path_resolver.get_database_path()) + await temp_db_service.initialize() + Container.core_database.override(providers.Singleton(lambda: temp_db_service)) + + # Mock cache service + mock_cache = MagicMock(spec=Cache) + mock_cache.configure_mock( + **{"get.return_value": None, "set.return_value": True, "ping.return_value": True} + ) + Container.cache_service.override(providers.Singleton(lambda: mock_cache)) + + # Mock redis client with in-memory storage for sessions + mock_redis = AsyncMock(spec=redis.asyncio.Redis) + redis_storage = {} + + async def mock_set(key, value, ex=None): + redis_storage[key] = value + return True + + async def mock_get(key): + return redis_storage.get(key) + + async def mock_delete(key): + redis_storage.pop(key, None) + return True + + mock_redis.set = AsyncMock(spec=object, side_effect=mock_set) + mock_redis.get = AsyncMock(spec=object, side_effect=mock_get) + mock_redis.delete = AsyncMock(spec=object, side_effect=mock_delete) + mock_redis.close = AsyncMock(spec=object) + Container.redis_client.override(providers.Singleton(lambda: mock_redis)) + + # Mock auth service + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + mock_admin = AdminUser( + username="admin", + password_hash=pwd_context.hash("testpassword"), + created_at=datetime.now(UTC), + ) + mock_auth_service.load_admin_user.return_value = mock_admin + mock_auth_service.verify_password.side_effect = lambda plain, hashed: pwd_context.verify( + plain, hashed + ) + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) + + # Override system control service + Container.system_control_service.override(providers.Object(mock_system_control)) + + # Create app with full auth setup + app = create_app() + + # Create client and authenticate + test_client = TestClient(app) + authenticate_sync_client(test_client) + yield test_client + + # Cleanup database + if hasattr(temp_db_service, "async_engine") and temp_db_service.async_engine: + await temp_db_service.async_engine.dispose() + + # Cleanup: reset all overrides + Container.path_resolver.reset_override() + Container.database_path.reset_override() + Container.config.reset_override() + Container.core_database.reset_override() + Container.cache_service.reset_override() + Container.redis_client.reset_override() + Container.auth_service.reset_override() + Container.system_control_service.reset_override() class TestSystemServicesAPIRoutes: diff --git a/tests/birdnetpi/web/routers/test_update_api_routes.py b/tests/birdnetpi/web/routers/test_update_api_routes.py index eefac133..7ae48654 100644 --- a/tests/birdnetpi/web/routers/test_update_api_routes.py +++ b/tests/birdnetpi/web/routers/test_update_api_routes.py @@ -5,9 +5,11 @@ @pytest.fixture -def client(app_with_temp_data): - """Create test client from app.""" - return TestClient(app_with_temp_data) +def client(app_with_temp_data, authenticate_sync_client): + """Create authenticated test client from app.""" + test_client = TestClient(app_with_temp_data) + authenticate_sync_client(test_client) + return test_client class TestCheckForUpdates: @@ -321,9 +323,10 @@ def test_update_flow_check_then_apply(self, client, cache): assert result_response.status_code == 200 assert result_response.json()["success"] is True - def test_all_update_endpoints_exist(self, app_with_temp_data): + def test_all_update_endpoints_exist(self, app_with_temp_data, authenticate_sync_client): """Should have all update endpoints available.""" client = TestClient(app_with_temp_data) + authenticate_sync_client(client) endpoints = [ ("/api/update/status", "GET"), diff --git a/tests/birdnetpi/web/routers/test_update_git_api.py b/tests/birdnetpi/web/routers/test_update_git_api.py index 1658f9a0..e785a419 100644 --- a/tests/birdnetpi/web/routers/test_update_git_api.py +++ b/tests/birdnetpi/web/routers/test_update_git_api.py @@ -51,10 +51,17 @@ class TestListGitRemotes: ids=["sbc_lists_remotes", "docker_returns_empty", "error_returns_500"], ) def test_list_remotes( - self, app_with_temp_data, deployment, service_error, expected_status, expected_remotes + self, + app_with_temp_data, + authenticate_sync_client, + deployment, + service_error, + expected_status, + expected_remotes, ): """Should handle listing git remotes based on deployment environment.""" client = TestClient(app_with_temp_data) + authenticate_sync_client(client) with patch( "birdnetpi.web.routers.update_api_routes.SystemUtils", autospec=True @@ -101,6 +108,7 @@ class TestAddGitRemote: def test_add_remote( self, app_with_temp_data, + authenticate_sync_client, deployment, service_error, expected_success, @@ -108,6 +116,7 @@ def test_add_remote( ): """Should handle adding git remotes based on deployment environment.""" client = TestClient(app_with_temp_data) + authenticate_sync_client(client) with patch( "birdnetpi.web.routers.update_api_routes.SystemUtils", autospec=True @@ -186,6 +195,7 @@ class TestUpdateGitRemote: def test_update_remote( self, app_with_temp_data, + authenticate_sync_client, deployment, old_name, new_name, @@ -197,6 +207,7 @@ def test_update_remote( ): """Should handle updating git remotes based on deployment and parameters.""" client = TestClient(app_with_temp_data) + authenticate_sync_client(client) with patch( "birdnetpi.web.routers.update_api_routes.SystemUtils", autospec=True @@ -249,6 +260,7 @@ class TestDeleteGitRemote: def test_delete_remote( self, app_with_temp_data, + authenticate_sync_client, deployment, remote_name, service_error, @@ -257,6 +269,7 @@ def test_delete_remote( ): """Should handle deleting git remotes based on deployment environment.""" client = TestClient(app_with_temp_data) + authenticate_sync_client(client) with patch( "birdnetpi.web.routers.update_api_routes.SystemUtils", autospec=True @@ -303,6 +316,7 @@ class TestListGitBranches: def test_list_branches( self, app_with_temp_data, + authenticate_sync_client, deployment, service_error, expected_status, @@ -311,6 +325,7 @@ def test_list_branches( ): """Should handle listing git branches based on deployment environment.""" client = TestClient(app_with_temp_data) + authenticate_sync_client(client) with patch( "birdnetpi.web.routers.update_api_routes.SystemUtils", autospec=True diff --git a/tests/birdnetpi/web/routers/test_update_view_routes.py b/tests/birdnetpi/web/routers/test_update_view_routes.py index ab754aec..3ed7c657 100644 --- a/tests/birdnetpi/web/routers/test_update_view_routes.py +++ b/tests/birdnetpi/web/routers/test_update_view_routes.py @@ -10,8 +10,8 @@ @pytest.fixture -def client(app_with_temp_data): - """Create test client from app.""" +def client(app_with_temp_data, authenticate_sync_client): + """Create authenticated test client from app.""" # Mount static files to avoid template rendering errors # Create a temporary static directory @@ -24,7 +24,10 @@ def client(app_with_temp_data): # Mount the static files app_with_temp_data.mount("/static", StaticFiles(directory=static_dir), name="static") - return TestClient(app_with_temp_data) + test_client = TestClient(app_with_temp_data) + authenticate_sync_client(test_client) + + return test_client class TestUpdateViewRoutes: diff --git a/tests/birdnetpi/web/test_notification_rules_ui.py b/tests/birdnetpi/web/test_notification_rules_ui.py index fcb1fce5..60bac771 100644 --- a/tests/birdnetpi/web/test_notification_rules_ui.py +++ b/tests/birdnetpi/web/test_notification_rules_ui.py @@ -14,6 +14,7 @@ from birdnetpi.audio.capture import AudioDeviceService from birdnetpi.config.manager import ConfigManager from birdnetpi.config.models import BirdNETConfig +from birdnetpi.utils.auth import AuthService from birdnetpi.web.core.container import Container from birdnetpi.web.core.factory import create_app @@ -43,8 +44,29 @@ def app_with_notification_rules(path_resolver, repo_root, mock_config_with_rules mock_audio_service = MagicMock(spec=AudioDeviceService) mock_audio_service.discover_input_devices.return_value = [] + # Mock AuthService to indicate admin exists (prevents setup redirect) + # and provide authentication for login + from datetime import datetime + + from birdnetpi.utils.auth import AdminUser, pwd_context + + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + + # Create a mock admin user with hashed password "testpassword" + mock_admin = AdminUser( + username="admin", + password_hash=pwd_context.hash("testpassword"), + created_at=datetime.now(), + ) + mock_auth_service.load_admin_user.return_value = mock_admin + mock_auth_service.verify_password.side_effect = lambda plain, hashed: pwd_context.verify( + plain, hashed + ) + # Override providers BEFORE creating the app Container.path_resolver.override(providers.Singleton(lambda: path_resolver)) + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) # Create templates templates_dir = path_resolver.get_templates_dir() @@ -73,13 +95,24 @@ def app_with_notification_rules(path_resolver, repo_root, mock_config_with_rules # Clean up overrides Container.path_resolver.reset_override() Container.templates.reset_override() + Container.auth_service.reset_override() @pytest.fixture def client_with_notification_rules(app_with_notification_rules): - """Create test client with notification rules.""" + """Create authenticated test client with notification rules.""" app, config_manager, audio_service = app_with_notification_rules + with TestClient(app) as test_client: + # Authenticate the test client by logging in + # This creates a session cookie that will be used for subsequent requests + login_response = test_client.post( + "/admin/login", + data={"username": "admin", "password": "testpassword"}, + follow_redirects=False, + ) + assert login_response.status_code == 303 # Successful login redirects + yield test_client, config_manager, audio_service diff --git a/tests/conftest.py b/tests/conftest.py index c2b851b2..dd70ed10 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,12 +11,14 @@ import matplotlib import pytest import redis +import redis.asyncio from dependency_injector import providers from sqlalchemy.engine import Result, Row from sqlalchemy.engine.result import MappingResult, ScalarResult from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel +from starlette.testclient import TestClient from birdnetpi.config import ConfigManager from birdnetpi.config.models import BirdNETConfig, UpdateConfig @@ -29,6 +31,7 @@ from birdnetpi.species.display import SpeciesDisplayService from birdnetpi.system.file_manager import FileManager from birdnetpi.system.path_resolver import PathResolver +from birdnetpi.utils.auth import AdminUser, AuthService, pwd_context from birdnetpi.utils.cache import Cache from birdnetpi.web.core.container import Container from birdnetpi.web.core.factory import create_app @@ -197,6 +200,44 @@ async def app_with_temp_data(path_resolver): ) Container.cache_service.override(providers.Singleton(lambda: mock_cache)) + # Mock the redis client to avoid event loop closure issues during test teardown + # and to properly store/retrieve session data for authentication + mock_redis = AsyncMock(spec=redis.asyncio.Redis) + # Create in-memory storage for sessions + redis_storage = {} + + async def mock_set(key, value, ex=None): + redis_storage[key] = value + return True + + async def mock_get(key): + return redis_storage.get(key) + + async def mock_delete(key): + redis_storage.pop(key, None) + return True + + mock_redis.set = AsyncMock(spec=object, side_effect=mock_set) + mock_redis.get = AsyncMock(spec=object, side_effect=mock_get) + mock_redis.delete = AsyncMock(spec=object, side_effect=mock_delete) + mock_redis.close = AsyncMock(spec=object) + Container.redis_client.override(providers.Singleton(lambda: mock_redis)) + + # Mock the auth service to enable authentication in tests + # Create a test admin user with hashed password "testpassword" + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + mock_admin = AdminUser( + username="admin", + password_hash=pwd_context.hash("testpassword"), + created_at=datetime.now(UTC), + ) + mock_auth_service.load_admin_user.return_value = mock_admin + mock_auth_service.verify_password.side_effect = lambda plain, hashed: pwd_context.verify( + plain, hashed + ) + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) + # Reset dependent services to ensure they use the overridden path_resolver # These are Singletons that depend on path_resolver and must be recreated # with the test path_resolver to prevent permission errors on /var/lib/birdnetpi @@ -229,6 +270,80 @@ async def app_with_temp_data(path_resolver): Container.config.reset_override() Container.core_database.reset_override() Container.cache_service.reset_override() + Container.redis_client.reset_override() + Container.auth_service.reset_override() + + +@pytest.fixture +def authenticate_sync_client(): + """Provide a function to authenticate a sync TestClient. + + Returns: + A callable that takes a TestClient and authenticates it + + Example: + def test_something(authenticate_sync_client): + client = TestClient(app) + authenticate_sync_client(client) + """ + + def _authenticate(client: TestClient) -> TestClient: + login_response = client.post( + "/admin/login", + data={"username": "admin", "password": "testpassword"}, + follow_redirects=False, + ) + assert login_response.status_code == 303 # Successful login redirects + return client + + return _authenticate + + +@pytest.fixture +def authenticate_async_client(): + """Provide a function to authenticate an async AsyncClient. + + Returns: + A callable that takes an AsyncClient and authenticates it + + Example: + async def test_something(authenticate_async_client): + async with AsyncClient(...) as client: + await authenticate_async_client(client) + """ + + async def _authenticate(client): + login_response = await client.post( + "/admin/login", + data={"username": "admin", "password": "testpassword"}, + follow_redirects=False, + ) + assert login_response.status_code == 303 # Successful login redirects + return client + + return _authenticate + + +@pytest.fixture +def authenticated_client(app_with_temp_data, authenticate_sync_client): + """Create an authenticated test client for routes that require authentication. + + This fixture: + 1. Uses the app_with_temp_data fixture (which mocks AuthService) + 2. Creates a TestClient + 3. Logs in with test credentials (username: admin, password: testpassword) + 4. Returns the authenticated client with session cookie + + Use this fixture for tests that access admin-protected routes. + + Example: + def test_protected_route(authenticated_client): + response = authenticated_client.get("/admin/settings") + assert response.status_code == 200 + """ + with TestClient(app_with_temp_data) as client: + authenticate_sync_client(client) + yield client @pytest.fixture diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index c96df765..b703910f 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -11,6 +11,62 @@ from birdnetpi.releases.asset_manifest import AssetManifest from birdnetpi.system.path_resolver import PathResolver +# Default test credentials for e2e tests +E2E_ADMIN_USERNAME = "admin" +E2E_ADMIN_PASSWORD = "e2e-test-password-123" + + +def _setup_admin_user(base_url: str = "http://localhost:8000") -> None: + """Create admin user for e2e tests if not already exists. + + Args: + base_url: The base URL of the BirdNET-Pi instance + """ + # Check if setup is needed (will redirect to setup if no admin exists) + response = httpx.get(f"{base_url}/", follow_redirects=False) + if response.status_code == 303 and "/admin/setup" in response.headers.get("location", ""): + # Create admin user + httpx.post( + f"{base_url}/admin/setup", + data={"username": E2E_ADMIN_USERNAME, "password": E2E_ADMIN_PASSWORD}, + follow_redirects=False, + ) + + +def _get_authenticated_client(base_url: str = "http://localhost:8000") -> httpx.Client: + """Get an authenticated httpx client for e2e tests. + + Args: + base_url: The base URL of the BirdNET-Pi instance + + Returns: + An httpx.Client with authentication cookies set + """ + client = httpx.Client(base_url=base_url) + # Login to get session cookie + client.post( + "/admin/login", + data={"username": E2E_ADMIN_USERNAME, "password": E2E_ADMIN_PASSWORD}, + follow_redirects=False, + ) + return client + + +@pytest.fixture +def authenticated_e2e_client() -> Generator[httpx.Client, None, None]: + """Fixture that provides an authenticated httpx client for standard e2e tests (port 8000).""" + client = _get_authenticated_client("http://localhost:8000") + yield client + client.close() + + +@pytest.fixture +def authenticated_profiling_client() -> Generator[httpx.Client, None, None]: + """Fixture that provides an authenticated httpx client for profiling tests (port 8001).""" + client = _get_authenticated_client("http://localhost:8001") + yield client + client.close() + @pytest.fixture(scope="module") def docker_compose_up_down() -> Generator[None, None, None]: @@ -55,6 +111,9 @@ def docker_compose_up_down() -> Generator[None, None, None]: subprocess.run([*compose_cmd, "down"], env=env, check=False) pytest.fail("Services did not become ready in time") + # Set up admin user for authentication + _setup_admin_user("http://localhost:8000") + yield # Bring down Docker Compose but keep the volume to preserve models diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index af50c6ad..77db2fd4 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -6,9 +6,10 @@ @pytest.mark.expensive -def test_root_endpoint_e2e(docker_compose_up_down) -> None: +def test_root_endpoint_e2e(docker_compose_up_down, authenticated_e2e_client) -> None: """Should serve the root endpoint of the BirdNET-Pi application.""" - response = httpx.get("http://localhost:8000") + # Need authenticated client since root page requires login + response = authenticated_e2e_client.get("/") assert response.status_code == 200 assert "BirdNET-Pi" in response.text @@ -62,10 +63,10 @@ def test_sqladmin_detection_list_e2e(docker_compose_up_down) -> None: result.check_returncode() # Wait for the FastAPI service to be fully ready after restart - # Retry the basic endpoint first to ensure the service is up + # Retry the health endpoint first to ensure the service is up for _attempt in range(10): try: - health_check = httpx.get("http://localhost:8000/", timeout=3) + health_check = httpx.get("http://localhost:8000/api/health/ready", timeout=3) if health_check.status_code == 200: break except Exception: @@ -74,7 +75,25 @@ def test_sqladmin_detection_list_e2e(docker_compose_up_down) -> None: else: pytest.fail("FastAPI service did not become ready after dummy data generation") - response = httpx.get("http://localhost:8000/admin/database/detection/list") + # Need authenticated client for SQLAdmin access + # Note: We need to recreate the client after container restart since session may be lost + # Re-setup admin and authenticate (container may have been restarted) + response = httpx.get("http://localhost:8000/", follow_redirects=False) + if response.status_code == 303 and "/admin/setup" in response.headers.get("location", ""): + httpx.post( + "http://localhost:8000/admin/setup", + data={"username": "admin", "password": "e2e-test-password-123"}, + follow_redirects=False, + ) + + client = httpx.Client(base_url="http://localhost:8000") + client.post( + "/admin/login", + data={"username": "admin", "password": "e2e-test-password-123"}, + follow_redirects=False, + ) + response = client.get("/admin/database/detection/list") + client.close() assert response.status_code == 200 assert "Detections" in response.text @@ -85,14 +104,14 @@ def test_sqladmin_detection_list_e2e(docker_compose_up_down) -> None: @pytest.mark.expensive -def test_profiling_disabled_by_default(docker_compose_up_down) -> None: +def test_profiling_disabled_by_default(docker_compose_up_down, authenticated_e2e_client) -> None: """Should not enable profiling when ENABLE_PROFILING is not set. This test is in the main e2e file because it needs the regular Docker environment without profiling enabled. """ # Request the root page with ?profile=1 - response = httpx.get("http://localhost:8000/?profile=1") + response = authenticated_e2e_client.get("/?profile=1") assert response.status_code == 200 # Should return the normal page, not profiling output diff --git a/tests/e2e/test_profiling_e2e.py b/tests/e2e/test_profiling_e2e.py index 323dd9f0..185e7ef2 100644 --- a/tests/e2e/test_profiling_e2e.py +++ b/tests/e2e/test_profiling_e2e.py @@ -50,6 +50,15 @@ def docker_compose_with_profiling() -> Generator[None, None, None]: subprocess.run([*compose_cmd, "--profile", "profiling", "down"], env=env, check=False) pytest.fail("Services did not become ready in time") + # Set up admin user for authentication (profiling uses port 8001) + response = httpx.get("http://localhost:8001/", follow_redirects=False) + if response.status_code == 303 and "/admin/setup" in response.headers.get("location", ""): + httpx.post( + "http://localhost:8001/admin/setup", + data={"username": "admin", "password": "e2e-test-password-123"}, + follow_redirects=False, + ) + yield # Tear down services (preserves test volume) @@ -57,10 +66,12 @@ def docker_compose_with_profiling() -> Generator[None, None, None]: @pytest.mark.expensive -def test_profiling_enabled_root_page(docker_compose_with_profiling) -> None: +def test_profiling_enabled_root_page( + docker_compose_with_profiling, authenticated_profiling_client +) -> None: """Should profiling works on the root page when enabled.""" # Request the root page with ?profile=1 - response = httpx.get("http://localhost:8001/?profile=1") + response = authenticated_profiling_client.get("/?profile=1") assert response.status_code == 200 # Should contain pyinstrument profiling output @@ -77,7 +88,9 @@ def test_profiling_enabled_root_page(docker_compose_with_profiling) -> None: @pytest.mark.expensive -def test_profiling_enabled_settings_page(docker_compose_with_profiling) -> None: +def test_profiling_enabled_settings_page( + docker_compose_with_profiling, authenticated_profiling_client +) -> None: """Should enable profiling on the settings page when ENABLE_PROFILING is set. The settings page is ideal for testing because it doesn't call @@ -85,7 +98,7 @@ def test_profiling_enabled_settings_page(docker_compose_with_profiling) -> None: inherent delays, providing a cleaner profiling output. """ # Request the settings page with ?profile=1 - response = httpx.get("http://localhost:8001/admin/settings?profile=1") + response = authenticated_profiling_client.get("/admin/settings?profile=1") assert response.status_code == 200 # Should contain pyinstrument profiling output @@ -99,10 +112,12 @@ def test_profiling_enabled_settings_page(docker_compose_with_profiling) -> None: @pytest.mark.expensive -def test_profiling_shows_system_calls(docker_compose_with_profiling) -> None: +def test_profiling_shows_system_calls( + docker_compose_with_profiling, authenticated_profiling_client +) -> None: """Should profiling output shows expected function calls.""" # Request the root page with profiling - response = httpx.get("http://localhost:8001/?profile=1") + response = authenticated_profiling_client.get("/?profile=1") assert response.status_code == 200 # Should contain pyinstrument profiling output @@ -124,10 +139,12 @@ def test_profiling_shows_system_calls(docker_compose_with_profiling) -> None: @pytest.mark.expensive -def test_profiling_normal_request_unaffected(docker_compose_with_profiling) -> None: +def test_profiling_normal_request_unaffected( + docker_compose_with_profiling, authenticated_profiling_client +) -> None: """Should requests without ?profile=1 work normally when profiling is enabled.""" # Request without profiling parameter - response = httpx.get("http://localhost:8001/") + response = authenticated_profiling_client.get("/") assert response.status_code == 200 # Should return normal page content diff --git a/tests/e2e/test_settings_e2e.py b/tests/e2e/test_settings_e2e.py index 830b2913..2bf4a4ab 100644 --- a/tests/e2e/test_settings_e2e.py +++ b/tests/e2e/test_settings_e2e.py @@ -3,10 +3,12 @@ import os import shutil import tempfile +from datetime import UTC, datetime from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest +import redis.asyncio import yaml from dependency_injector import providers from fastapi.templating import Jinja2Templates @@ -16,6 +18,7 @@ from birdnetpi.audio.devices import AudioDevice, AudioDeviceService from birdnetpi.config import BirdNETConfig, ConfigManager from birdnetpi.system.path_resolver import PathResolver +from birdnetpi.utils.auth import AdminUser, AuthService, pwd_context from birdnetpi.web.core.container import Container from birdnetpi.web.core.factory import create_app @@ -116,6 +119,41 @@ def e2e_app(self, temp_data_dir, mock_audio_devices, repo_root): providers.Singleton(lambda: Jinja2Templates(directory=str(templates_dir))) ) + # Mock redis client with in-memory storage for sessions + mock_redis = AsyncMock(spec=redis.asyncio.Redis) + redis_storage = {} + + async def mock_set(key, value, ex=None): + redis_storage[key] = value + return True + + async def mock_get(key): + return redis_storage.get(key) + + async def mock_delete(key): + redis_storage.pop(key, None) + return True + + mock_redis.set = AsyncMock(spec=object, side_effect=mock_set) + mock_redis.get = AsyncMock(spec=object, side_effect=mock_get) + mock_redis.delete = AsyncMock(spec=object, side_effect=mock_delete) + mock_redis.close = AsyncMock(spec=object) + Container.redis_client.override(providers.Singleton(lambda: mock_redis)) + + # Mock auth service + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + mock_admin = AdminUser( + username="admin", + password_hash=pwd_context.hash("testpassword"), + created_at=datetime.now(UTC), + ) + mock_auth_service.load_admin_user.return_value = mock_admin + mock_auth_service.verify_password.side_effect = lambda plain, hashed: pwd_context.verify( + plain, hashed + ) + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) + # Patch AudioDeviceService at import with ( patch( @@ -134,12 +172,15 @@ def e2e_app(self, temp_data_dir, mock_audio_devices, repo_root): # Cleanup overrides Container.path_resolver.reset_override() Container.templates.reset_override() + Container.redis_client.reset_override() + Container.auth_service.reset_override() - def test_e2e_settings_page_loads_with_current_config(self, e2e_app): + def test_e2e_settings_page_loads_with_current_config(self, e2e_app, authenticate_sync_client): """Should settings page loads and displays current configuration.""" app, _path_resolver, _ = e2e_app with TestClient(app) as client: + authenticate_sync_client(client) # GET settings page response = client.get("/admin/settings") @@ -162,11 +203,14 @@ def test_e2e_settings_page_loads_with_current_config(self, e2e_app): assert 'action="/admin/settings"' in response.text assert 'method="post"' in response.text - def test_e2e_settings_form_submission_saves_changes(self, e2e_app, temp_data_dir): + def test_e2e_settings_form_submission_saves_changes( + self, e2e_app, temp_data_dir, authenticate_sync_client + ): """Should submitting the settings form saves changes to config file.""" app, path_resolver, _ = e2e_app with TestClient(app) as client: + authenticate_sync_client(client) # Submit form with changed values form_data = { "site_name": "Updated E2E Site", @@ -209,11 +253,12 @@ def test_e2e_settings_form_submission_saves_changes(self, e2e_app, temp_data_dir assert saved_data["enable_gps"] is True assert saved_data["birdweather_id"] == "test123" - def test_e2e_settings_roundtrip(self, e2e_app): + def test_e2e_settings_roundtrip(self, e2e_app, authenticate_sync_client): """Should complete roundtrip: load, modify, save, reload.""" app, _path_resolver, _ = e2e_app with TestClient(app) as client: + authenticate_sync_client(client) # Step 1: Load initial settings page response1 = client.get("/admin/settings") assert response1.status_code == 200 @@ -253,11 +298,12 @@ def test_e2e_settings_roundtrip(self, e2e_app): assert "35.6762" in response3.text assert "139.6503" in response3.text - def test_e2e_settings_validation_errors(self, e2e_app): + def test_e2e_settings_validation_errors(self, e2e_app, authenticate_sync_client): """Should invalid form data is handled properly.""" app, _, _ = e2e_app with TestClient(app) as client: + authenticate_sync_client(client) # Try to submit with missing required fields form_data = { "site_name": "", # Empty required field @@ -276,11 +322,12 @@ def test_e2e_settings_validation_errors(self, e2e_app): # Form validation might raise an exception assert "ValidationError" in str(type(e).__name__) or "ValueError" in str(e) - def test_e2e_settings_handles_concurrent_access(self, e2e_app): + def test_e2e_settings_handles_concurrent_access(self, e2e_app, authenticate_sync_client): """Should handle concurrent access to settings (simulated).""" app, path_resolver, _ = e2e_app with TestClient(app) as client: + authenticate_sync_client(client) # Simulate two users accessing settings simultaneously # User 1 loads the page @@ -339,7 +386,7 @@ def test_e2e_settings_handles_concurrent_access(self, e2e_app): assert final_config.site_name == "User 2 Site" assert final_config.latitude == 30.0 - def test_e2e_settings_preserves_unmodified_fields(self, e2e_app): + def test_e2e_settings_preserves_unmodified_fields(self, e2e_app, authenticate_sync_client): """Should fields not in the form are preserved during save.""" app, path_resolver, _ = e2e_app @@ -351,6 +398,7 @@ def test_e2e_settings_preserves_unmodified_fields(self, e2e_app): config_manager.save(config) with TestClient(app) as client: + authenticate_sync_client(client) # Submit form (without git fields) form_data = { "site_name": "Preserve Test", diff --git a/tests/integration/test_app_startup_integration.py b/tests/integration/test_app_startup_integration.py index 97ab3631..fab9ab4f 100644 --- a/tests/integration/test_app_startup_integration.py +++ b/tests/integration/test_app_startup_integration.py @@ -1,14 +1,18 @@ """Integration test that verifies the FastAPI app starts with real Container.""" import shutil +from datetime import UTC, datetime from pathlib import Path +from unittest.mock import AsyncMock, MagicMock import pytest +import redis.asyncio from dependency_injector import providers from fastapi import FastAPI from fastapi.testclient import TestClient from birdnetpi.system.path_resolver import PathResolver +from birdnetpi.utils.auth import AdminUser, AuthService, pwd_context from birdnetpi.web.core.container import Container from birdnetpi.web.core.factory import create_app @@ -72,6 +76,41 @@ def app_with_real_container(self, test_resolver: PathResolver): providers.Factory(lambda: test_resolver.get_database_path()) ) + # Mock redis client with in-memory storage for sessions + mock_redis = AsyncMock(spec=redis.asyncio.Redis) + redis_storage = {} + + async def mock_set(key, value, ex=None): + redis_storage[key] = value + return True + + async def mock_get(key): + return redis_storage.get(key) + + async def mock_delete(key): + redis_storage.pop(key, None) + return True + + mock_redis.set = AsyncMock(spec=object, side_effect=mock_set) + mock_redis.get = AsyncMock(spec=object, side_effect=mock_get) + mock_redis.delete = AsyncMock(spec=object, side_effect=mock_delete) + mock_redis.close = AsyncMock(spec=object) + Container.redis_client.override(providers.Singleton(lambda: mock_redis)) + + # Mock auth service + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + mock_admin = AdminUser( + username="admin", + password_hash=pwd_context.hash("testpassword"), + created_at=datetime.now(UTC), + ) + mock_auth_service.load_admin_user.return_value = mock_admin + mock_auth_service.verify_password.side_effect = lambda plain, hashed: pwd_context.verify( + plain, hashed + ) + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) + # Create the app using the factory app = create_app() @@ -80,6 +119,8 @@ def app_with_real_container(self, test_resolver: PathResolver): # Clean up overrides Container.path_resolver.reset_override() Container.database_path.reset_override() + Container.redis_client.reset_override() + Container.auth_service.reset_override() def test_app_creation_succeeds(self, app_with_real_container: FastAPI): """Should create app without errors.""" @@ -114,9 +155,10 @@ def test_root_endpoint_works(self, app_with_real_container: FastAPI): assert response.status_code == 200 assert "html" in response.text.lower() - def test_api_endpoint_works(self, app_with_real_container: FastAPI): + def test_api_endpoint_works(self, app_with_real_container: FastAPI, authenticate_sync_client): """Should an API endpoint works with the real container.""" with TestClient(app_with_real_container) as client: + authenticate_sync_client(client) response = client.get("/api/system/hardware/status") assert response.status_code == 200 data = response.json() diff --git a/tests/integration/test_ebird_detection_filtering_integration.py b/tests/integration/test_ebird_detection_filtering_integration.py index 6f980a36..23bf890c 100644 --- a/tests/integration/test_ebird_detection_filtering_integration.py +++ b/tests/integration/test_ebird_detection_filtering_integration.py @@ -17,6 +17,7 @@ from birdnetpi.database.core import CoreDatabaseService from birdnetpi.database.ebird import EBirdRegionService from birdnetpi.releases.registry_service import BoundingBox, RegionPackInfo, RegistryService +from birdnetpi.utils.auth import AdminUser, AuthService, pwd_context from birdnetpi.utils.cache.cache import Cache from birdnetpi.web.core.container import Container from birdnetpi.web.core.factory import create_app @@ -120,6 +121,20 @@ async def app_with_ebird_filtering(path_resolver, mock_ebird_service, tmp_path): ) Container.registry_service.override(providers.Singleton(lambda: mock_registry_service)) + # Mock AuthService to enable authentication in tests + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + mock_admin = AdminUser( + username="admin", + password_hash=pwd_context.hash("testpassword"), + created_at=datetime.now(UTC), + ) + mock_auth_service.load_admin_user.return_value = mock_admin + mock_auth_service.verify_password.side_effect = lambda plain, hashed: pwd_context.verify( + plain, hashed + ) + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) + # Reset dependent services try: Container.ebird_region_service.reset() @@ -150,12 +165,25 @@ async def app_with_ebird_filtering(path_resolver, mock_ebird_service, tmp_path): Container.cache_service.reset_override() Container.ebird_region_service.reset_override() Container.registry_service.reset_override() + Container.auth_service.reset_override() + + +@pytest.fixture +async def authenticated_client(app_with_ebird_filtering, authenticate_async_client): + """Create an authenticated AsyncClient for API testing.""" + async with AsyncClient( + transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" + ) as client: + await authenticate_async_client(client) + yield client class TestEBirdFilteringDisabled: """Test that detections are allowed when eBird filtering is disabled.""" - async def test_detection_allowed_when_filtering_disabled(self, app_with_temp_data): + async def test_detection_allowed_when_filtering_disabled( + self, app_with_temp_data, authenticate_async_client + ): """Should allow detection when eBird filtering is disabled.""" # Ensure filtering is disabled config = Container.config() @@ -164,6 +192,7 @@ async def test_detection_allowed_when_filtering_disabled(self, app_with_temp_dat async with AsyncClient( transport=ASGITransport(app=app_with_temp_data), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -184,7 +213,9 @@ class TestEBirdFilteringModeOff: # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_detection_allowed_when_mode_off(self, app_with_ebird_filtering): + async def test_detection_allowed_when_mode_off( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should allow detection when detection_mode is 'off'.""" # Set mode to off config = Container.config() @@ -193,6 +224,7 @@ async def test_detection_allowed_when_mode_off(self, app_with_ebird_filtering): async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -212,7 +244,9 @@ class TestEBirdFilteringWarnMode: # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_vagrant_species_warned_but_allowed(self, app_with_ebird_filtering): + async def test_vagrant_species_warned_but_allowed( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should warn about vagrant species but still create detection.""" # Set mode to warn config = Container.config() @@ -226,6 +260,7 @@ async def test_vagrant_species_warned_but_allowed(self, app_with_ebird_filtering async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -246,7 +281,9 @@ class TestEBirdFilteringFilterMode: # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_vagrant_species_blocked_with_vagrant_strictness(self, app_with_ebird_filtering): + async def test_vagrant_species_blocked_with_vagrant_strictness( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should block vagrant species with vagrant strictness.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -259,6 +296,7 @@ async def test_vagrant_species_blocked_with_vagrant_strictness(self, app_with_eb async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -275,7 +313,9 @@ async def test_vagrant_species_blocked_with_vagrant_strictness(self, app_with_eb # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_rare_species_blocked_with_rare_strictness(self, app_with_ebird_filtering): + async def test_rare_species_blocked_with_rare_strictness( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should block rare species with rare strictness.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -288,6 +328,7 @@ async def test_rare_species_blocked_with_rare_strictness(self, app_with_ebird_fi async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -303,7 +344,7 @@ async def test_rare_species_blocked_with_rare_strictness(self, app_with_ebird_fi assert "filtered" in data["message"].lower() async def test_uncommon_species_blocked_with_uncommon_strictness( - self, app_with_ebird_filtering + self, app_with_ebird_filtering, authenticate_async_client ): """Should block uncommon species with uncommon strictness.""" config = Container.config() @@ -317,6 +358,7 @@ async def test_uncommon_species_blocked_with_uncommon_strictness( async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -333,7 +375,9 @@ async def test_uncommon_species_blocked_with_uncommon_strictness( # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_common_species_allowed_with_all_strictness(self, app_with_ebird_filtering): + async def test_common_species_allowed_with_all_strictness( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should allow common species with any strictness level.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -348,6 +392,7 @@ async def test_common_species_allowed_with_all_strictness(self, app_with_ebird_f async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -367,7 +412,9 @@ class TestEBirdFilteringUnknownSpecies: # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_unknown_species_allowed_with_allow_behavior(self, app_with_ebird_filtering): + async def test_unknown_species_allowed_with_allow_behavior( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should allow unknown species when behavior is 'allow'.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -379,6 +426,7 @@ async def test_unknown_species_allowed_with_allow_behavior(self, app_with_ebird_ async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -394,7 +442,9 @@ async def test_unknown_species_allowed_with_allow_behavior(self, app_with_ebird_ # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_unknown_species_blocked_with_block_behavior(self, app_with_ebird_filtering): + async def test_unknown_species_blocked_with_block_behavior( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should block unknown species when behavior is 'block'.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -405,6 +455,7 @@ async def test_unknown_species_blocked_with_block_behavior(self, app_with_ebird_ async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -425,7 +476,9 @@ class TestEBirdFilteringWithoutCoordinates: # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_detection_rejected_without_latitude(self, app_with_ebird_filtering): + async def test_detection_rejected_without_latitude( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should reject detection with validation error when latitude is missing.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -433,6 +486,7 @@ async def test_detection_rejected_without_latitude(self, app_with_ebird_filterin async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) # Create payload and remove latitude field payload = create_detection_payload( species_tensor="Turdus migratorius_American Robin", @@ -450,7 +504,9 @@ async def test_detection_rejected_without_latitude(self, app_with_ebird_filterin # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_detection_rejected_without_longitude(self, app_with_ebird_filtering): + async def test_detection_rejected_without_longitude( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should reject detection with validation error when longitude is missing.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -458,6 +514,7 @@ async def test_detection_rejected_without_longitude(self, app_with_ebird_filteri async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) # Create payload and remove longitude field payload = create_detection_payload( species_tensor="Turdus migratorius_American Robin", @@ -479,7 +536,9 @@ class TestEBirdFilteringErrorHandling: # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_detection_allowed_on_ebird_service_error(self, app_with_ebird_filtering): + async def test_detection_allowed_on_ebird_service_error( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should allow detection if eBird service fails.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -495,6 +554,7 @@ async def failing_attach(*args, **kwargs): async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -515,7 +575,9 @@ class TestEBirdFilteringStrictnessLevels: # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_vagrant_strictness_allows_rare_uncommon_common(self, app_with_ebird_filtering): + async def test_vagrant_strictness_allows_rare_uncommon_common( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should only block vagrant species with vagrant strictness.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -537,6 +599,7 @@ async def test_vagrant_strictness_allows_rare_uncommon_common(self, app_with_ebi async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -555,7 +618,9 @@ async def test_vagrant_strictness_allows_rare_uncommon_common(self, app_with_ebi # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_rare_strictness_allows_uncommon_common(self, app_with_ebird_filtering): + async def test_rare_strictness_allows_uncommon_common( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should block vagrant and rare species with rare strictness.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -576,6 +641,7 @@ async def test_rare_strictness_allows_uncommon_common(self, app_with_ebird_filte async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -594,7 +660,9 @@ async def test_rare_strictness_allows_uncommon_common(self, app_with_ebird_filte # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_uncommon_strictness_allows_only_common(self, app_with_ebird_filtering): + async def test_uncommon_strictness_allows_only_common( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should only allow common species with uncommon strictness.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -615,6 +683,7 @@ async def test_uncommon_strictness_allows_only_common(self, app_with_ebird_filte async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -633,7 +702,9 @@ async def test_uncommon_strictness_allows_only_common(self, app_with_ebird_filte # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_common_strictness_allows_only_common(self, app_with_ebird_filtering): + async def test_common_strictness_allows_only_common( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should only allow common species with common strictness.""" config = Container.config() config.ebird_filtering.detection_mode = "filter" @@ -654,6 +725,7 @@ async def test_common_strictness_allows_only_common(self, app_with_ebird_filteri async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( diff --git a/tests/integration/test_ebird_detection_filtering_simple.py b/tests/integration/test_ebird_detection_filtering_simple.py index e1b8d8e2..6e95b6b6 100644 --- a/tests/integration/test_ebird_detection_filtering_simple.py +++ b/tests/integration/test_ebird_detection_filtering_simple.py @@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +import redis.asyncio from dependency_injector import providers from httpx import ASGITransport, AsyncClient @@ -15,6 +16,7 @@ from birdnetpi.database.core import CoreDatabaseService from birdnetpi.database.ebird import EBirdRegionService from birdnetpi.releases.registry_service import BoundingBox, RegionPackInfo, RegistryService +from birdnetpi.utils.auth import AdminUser, AuthService, pwd_context from birdnetpi.utils.cache import Cache from birdnetpi.web.core.container import Container from birdnetpi.web.core.factory import create_app @@ -97,6 +99,41 @@ async def app_with_ebird_filtering(mock_ebird_service, path_resolver, tmp_path): ) Container.cache_service.override(providers.Singleton(lambda: mock_cache)) + # Mock redis client with in-memory storage for sessions + mock_redis = AsyncMock(spec=redis.asyncio.Redis) + redis_storage = {} + + async def mock_set(key, value, ex=None): + redis_storage[key] = value + return True + + async def mock_get(key): + return redis_storage.get(key) + + async def mock_delete(key): + redis_storage.pop(key, None) + return True + + mock_redis.set = AsyncMock(spec=object, side_effect=mock_set) + mock_redis.get = AsyncMock(spec=object, side_effect=mock_get) + mock_redis.delete = AsyncMock(spec=object, side_effect=mock_delete) + mock_redis.close = AsyncMock(spec=object) + Container.redis_client.override(providers.Singleton(lambda: mock_redis)) + + # Mock auth service + mock_auth_service = MagicMock(spec=AuthService) + mock_auth_service.admin_exists.return_value = True + mock_admin = AdminUser( + username="admin", + password_hash=pwd_context.hash("testpassword"), + created_at=datetime.now(UTC), + ) + mock_auth_service.load_admin_user.return_value = mock_admin + mock_auth_service.verify_password.side_effect = lambda plain, hashed: pwd_context.verify( + plain, hashed + ) + Container.auth_service.override(providers.Singleton(lambda: mock_auth_service)) + # Override the eBird service in the container BEFORE creating app Container.ebird_region_service.override(providers.Singleton(lambda: mock_ebird_service)) @@ -144,6 +181,8 @@ async def app_with_ebird_filtering(mock_ebird_service, path_resolver, tmp_path): Container.config.reset_override() Container.core_database.reset_override() Container.cache_service.reset_override() + Container.redis_client.reset_override() + Container.auth_service.reset_override() Container.ebird_region_service.reset_override() Container.registry_service.reset_override() @@ -153,7 +192,9 @@ class TestEBirdFilteringIntegration: # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_vagrant_species_blocked_in_filter_mode(self, app_with_ebird_filtering): + async def test_vagrant_species_blocked_in_filter_mode( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should block vagrant species when filtering is enabled.""" # Configure mock to return vagrant tier mock_service = app_with_ebird_filtering._mock_ebird_service @@ -162,6 +203,7 @@ async def test_vagrant_species_blocked_in_filter_mode(self, app_with_ebird_filte async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -179,7 +221,9 @@ async def test_vagrant_species_blocked_in_filter_mode(self, app_with_ebird_filte # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_common_species_allowed(self, app_with_ebird_filtering): + async def test_common_species_allowed( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should allow common species regardless of strictness.""" # Configure mock to return common tier mock_service = app_with_ebird_filtering._mock_ebird_service @@ -188,6 +232,7 @@ async def test_common_species_allowed(self, app_with_ebird_filtering): async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -202,7 +247,7 @@ async def test_common_species_allowed(self, app_with_ebird_filtering): # Detection should be created assert data["detection_id"] is not None - async def test_filtering_disabled(self, app_with_temp_data): + async def test_filtering_disabled(self, app_with_temp_data, authenticate_async_client): """Should allow all detections when filtering is disabled.""" # Ensure filtering is disabled config = Container.config() @@ -211,6 +256,7 @@ async def test_filtering_disabled(self, app_with_temp_data): async with AsyncClient( transport=ASGITransport(app=app_with_temp_data), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -226,7 +272,9 @@ async def test_filtering_disabled(self, app_with_temp_data): # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_unknown_species_behavior(self, app_with_ebird_filtering): + async def test_unknown_species_behavior( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should handle unknown species according to configuration.""" # Mock service returns None (species not found) # Config has unknown_species_behavior = "allow" by default @@ -234,6 +282,7 @@ async def test_unknown_species_behavior(self, app_with_ebird_filtering): async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( @@ -250,7 +299,9 @@ async def test_unknown_species_behavior(self, app_with_ebird_filtering): # Using app_with_ebird_filtering instead of app_with_temp_data because we need # eBird filtering enabled with mocked eBird service for this integration test - async def test_warn_mode_creates_detection(self, app_with_ebird_filtering): + async def test_warn_mode_creates_detection( + self, app_with_ebird_filtering, authenticate_async_client + ): """Should create detection in warn mode even when species would be filtered.""" # Set mode to warn config = Container.config() @@ -263,6 +314,7 @@ async def test_warn_mode_creates_detection(self, app_with_ebird_filtering): async with AsyncClient( transport=ASGITransport(app=app_with_ebird_filtering), base_url="http://test" ) as client: + await authenticate_async_client(client) response = await client.post( "/api/detections/", json=create_detection_payload( diff --git a/tests/integration/test_routes_integration.py b/tests/integration/test_routes_integration.py index 2d6a1213..105041d8 100644 --- a/tests/integration/test_routes_integration.py +++ b/tests/integration/test_routes_integration.py @@ -49,9 +49,11 @@ def integration_app(app_with_temp_data, tmp_path): @pytest.fixture -def integration_client(integration_app): +def integration_client(integration_app, authenticate_sync_client): """Create test client with integration app.""" - return TestClient(integration_app) + client = TestClient(integration_app) + authenticate_sync_client(client) + return client class TestSystemRoutesIntegration: diff --git a/tests/integration/test_update_integration.py b/tests/integration/test_update_integration.py index ac7eeb31..33c6ca5e 100644 --- a/tests/integration/test_update_integration.py +++ b/tests/integration/test_update_integration.py @@ -14,8 +14,8 @@ @pytest.fixture -def client(app_with_temp_data): - """Create test client from app.""" +def client(app_with_temp_data, authenticate_sync_client): + """Create authenticated test client from app.""" # Mount static files to avoid template rendering errors # Create a temporary static directory @@ -28,7 +28,10 @@ def client(app_with_temp_data): # Mount the static files app_with_temp_data.mount("/static", StaticFiles(directory=static_dir), name="static") - return TestClient(app_with_temp_data) + test_client = TestClient(app_with_temp_data) + authenticate_sync_client(test_client) + + return test_client @pytest.fixture diff --git a/tests/integration/test_update_integration_unhappy.py b/tests/integration/test_update_integration_unhappy.py index 03e0c876..db7b81f3 100644 --- a/tests/integration/test_update_integration_unhappy.py +++ b/tests/integration/test_update_integration_unhappy.py @@ -12,8 +12,8 @@ @pytest.fixture -def client(app_with_temp_data): - """Create test client from app.""" +def client(app_with_temp_data, authenticate_sync_client): + """Create authenticated test client from app.""" # Mount static files to avoid template rendering errors # Create a temporary static directory @@ -26,7 +26,10 @@ def client(app_with_temp_data): # Mount the static files app_with_temp_data.mount("/static", StaticFiles(directory=static_dir), name="static") - return TestClient(app_with_temp_data) + test_client = TestClient(app_with_temp_data) + authenticate_sync_client(test_client) + + return test_client class TestAPIErrorHandling: diff --git a/uv.lock b/uv.lock index c6ca67eb..e59befef 100644 --- a/uv.lock +++ b/uv.lock @@ -158,6 +158,39 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a4/3a/52ec5cf9ed71e3a068a208603e38471b686a59b548dda771f2975166e4a9/apprise-1.2.1-py2.py3-none-any.whl", hash = "sha256:679fb5d6232ec7748eef308bda9fe0f0707fa5e48eab247bfde83c3a024a21d1", size = 1109380, upload-time = "2022-12-28T14:58:19.436Z" }, ] +[[package]] +name = "argon2-cffi" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "argon2-cffi-bindings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/89/ce5af8a7d472a67cc819d5d998aa8c82c5d860608c4db9f46f1162d7dab9/argon2_cffi-25.1.0.tar.gz", hash = "sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1", size = 45706, upload-time = "2025-06-03T06:55:32.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/d3/a8b22fa575b297cd6e3e3b0155c7e25db170edf1c74783d6a31a2490b8d9/argon2_cffi-25.1.0-py3-none-any.whl", hash = "sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741", size = 14657, upload-time = "2025-06-03T06:55:30.804Z" }, +] + +[[package]] +name = "argon2-cffi-bindings" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/2d/db8af0df73c1cf454f71b2bbe5e356b8c1f8041c979f505b3d3186e520a9/argon2_cffi_bindings-25.1.0.tar.gz", hash = "sha256:b957f3e6ea4d55d820e40ff76f450952807013d361a65d7f28acc0acbf29229d", size = 1783441, upload-time = "2025-07-30T10:02:05.147Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/57/96b8b9f93166147826da5f90376e784a10582dd39a393c99bb62cfcf52f0/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:aecba1723ae35330a008418a91ea6cfcedf6d31e5fbaa056a166462ff066d500", size = 54121, upload-time = "2025-07-30T10:01:50.815Z" }, + { url = "https://files.pythonhosted.org/packages/0a/08/a9bebdb2e0e602dde230bdde8021b29f71f7841bd54801bcfd514acb5dcf/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2630b6240b495dfab90aebe159ff784d08ea999aa4b0d17efa734055a07d2f44", size = 29177, upload-time = "2025-07-30T10:01:51.681Z" }, + { url = "https://files.pythonhosted.org/packages/b6/02/d297943bcacf05e4f2a94ab6f462831dc20158614e5d067c35d4e63b9acb/argon2_cffi_bindings-25.1.0-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:7aef0c91e2c0fbca6fc68e7555aa60ef7008a739cbe045541e438373bc54d2b0", size = 31090, upload-time = "2025-07-30T10:01:53.184Z" }, + { url = "https://files.pythonhosted.org/packages/c1/93/44365f3d75053e53893ec6d733e4a5e3147502663554b4d864587c7828a7/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1e021e87faa76ae0d413b619fe2b65ab9a037f24c60a1e6cc43457ae20de6dc6", size = 81246, upload-time = "2025-07-30T10:01:54.145Z" }, + { url = "https://files.pythonhosted.org/packages/09/52/94108adfdd6e2ddf58be64f959a0b9c7d4ef2fa71086c38356d22dc501ea/argon2_cffi_bindings-25.1.0-cp39-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e924cfc503018a714f94a49a149fdc0b644eaead5d1f089330399134fa028a", size = 87126, upload-time = "2025-07-30T10:01:55.074Z" }, + { url = "https://files.pythonhosted.org/packages/72/70/7a2993a12b0ffa2a9271259b79cc616e2389ed1a4d93842fac5a1f923ffd/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c87b72589133f0346a1cb8d5ecca4b933e3c9b64656c9d175270a000e73b288d", size = 80343, upload-time = "2025-07-30T10:01:56.007Z" }, + { url = "https://files.pythonhosted.org/packages/78/9a/4e5157d893ffc712b74dbd868c7f62365618266982b64accab26bab01edc/argon2_cffi_bindings-25.1.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:1db89609c06afa1a214a69a462ea741cf735b29a57530478c06eb81dd403de99", size = 86777, upload-time = "2025-07-30T10:01:56.943Z" }, + { url = "https://files.pythonhosted.org/packages/74/cd/15777dfde1c29d96de7f18edf4cc94c385646852e7c7b0320aa91ccca583/argon2_cffi_bindings-25.1.0-cp39-abi3-win32.whl", hash = "sha256:473bcb5f82924b1becbb637b63303ec8d10e84c8d241119419897a26116515d2", size = 27180, upload-time = "2025-07-30T10:01:57.759Z" }, + { url = "https://files.pythonhosted.org/packages/e2/c6/a759ece8f1829d1f162261226fbfd2c6832b3ff7657384045286d2afa384/argon2_cffi_bindings-25.1.0-cp39-abi3-win_amd64.whl", hash = "sha256:a98cd7d17e9f7ce244c0803cad3c23a7d379c301ba618a5fa76a67d116618b98", size = 31715, upload-time = "2025-07-30T10:01:58.56Z" }, + { url = "https://files.pythonhosted.org/packages/42/b9/f8d6fa329ab25128b7e98fd83a3cb34d9db5b059a9847eddb840a0af45dd/argon2_cffi_bindings-25.1.0-cp39-abi3-win_arm64.whl", hash = "sha256:b0fdbcf513833809c882823f98dc2f931cf659d9a1429616ac3adebb49f5db94", size = 27149, upload-time = "2025-07-30T10:01:59.329Z" }, +] + [[package]] name = "astral" version = "3.2" @@ -249,6 +282,7 @@ dependencies = [ { name = "packaging" }, { name = "paho-mqtt" }, { name = "pandas" }, + { name = "passlib", extra = ["argon2"] }, { name = "pip" }, { name = "plotly" }, { name = "psutil" }, @@ -264,6 +298,7 @@ dependencies = [ { name = "sqladmin" }, { name = "sqlalchemy" }, { name = "sqlmodel" }, + { name = "starsessions", extra = ["redis"] }, { name = "structlog" }, { name = "suntime" }, { name = "tqdm" }, @@ -339,6 +374,7 @@ requires-dist = [ { name = "packaging", specifier = ">=25.0" }, { name = "paho-mqtt" }, { name = "pandas" }, + { name = "passlib", extras = ["argon2"], specifier = ">=1.7.4" }, { name = "pillow", marker = "(platform_machine == 'aarch64' and extra == 'epaper') or (platform_machine == 'armv7l' and extra == 'epaper')", specifier = ">=10.0.0" }, { name = "pip" }, { name = "plotly" }, @@ -357,6 +393,7 @@ requires-dist = [ { name = "sqladmin", specifier = ">=0.21.0" }, { name = "sqlalchemy" }, { name = "sqlmodel", specifier = ">=0.0.24" }, + { name = "starsessions", extras = ["redis"], specifier = ">=2.2.1" }, { name = "structlog", specifier = ">=25.4.0" }, { name = "suntime" }, { name = "tqdm", specifier = ">=4.67.1" }, @@ -805,6 +842,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, ] +[[package]] +name = "itsdangerous" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9c/cb/8ac0172223afbccb63986cc25049b154ecfb5e85932587206f42317be31d/itsdangerous-2.2.0.tar.gz", hash = "sha256:e0050c0b7da1eea53ffaf149c0cfbb5c6e2e2b69c4bef22c81fa6eb73e5f6173", size = 54410, upload-time = "2024-04-16T21:28:15.614Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" }, +] + [[package]] name = "jetson-gpio" version = "2.1.12" @@ -1195,6 +1241,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/75/cb/09d5f9bf7c8659af134ae0ffc1a349038a5d0ff93e45aedc225bde2872a3/pandas_stubs-2.3.0.250703-py3-none-any.whl", hash = "sha256:a9265fc69909f0f7a9cabc5f596d86c9d531499fed86b7838fd3278285d76b81", size = 154719, upload-time = "2025-07-02T17:49:10.697Z" }, ] +[[package]] +name = "passlib" +version = "1.7.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b6/06/9da9ee59a67fae7761aab3ccc84fa4f3f33f125b370f1ccdb915bf967c11/passlib-1.7.4.tar.gz", hash = "sha256:defd50f72b65c5402ab2c573830a6978e5f202ad0d984793c8dde2c4152ebe04", size = 689844, upload-time = "2020-10-08T19:00:52.121Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/a4/ab6b7589382ca3df236e03faa71deac88cae040af60c071a78d254a62172/passlib-1.7.4-py2.py3-none-any.whl", hash = "sha256:aa6bca462b8d8bda89c70b382f0c298a20b5560af6cbfa2dce410c0a2fb669f1", size = 525554, upload-time = "2020-10-08T19:00:49.856Z" }, +] + +[package.optional-dependencies] +argon2 = [ + { name = "argon2-cffi" }, +] + [[package]] name = "pillow" version = "11.3.0" @@ -1877,6 +1937,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/82/95/38ef0cd7fa11eaba6a99b3c4f5ac948d8bc6ff199aabd327a29cc000840c/starlette-0.47.1-py3-none-any.whl", hash = "sha256:5e11c9f5c7c3f24959edbf2dffdc01bba860228acf657129467d8a7468591527", size = 72747, upload-time = "2025-06-21T04:03:15.705Z" }, ] +[[package]] +name = "starsessions" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "itsdangerous" }, + { name = "starlette" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/a1/dd738cd47b7a1c681cae49c4f7c88cc953b2aca4de455c2aacda6652e7ce/starsessions-2.2.1.tar.gz", hash = "sha256:ce5e4448d9bf2c76222e56cd099ad92d22313e8a4def612e22b71a122cc11da0", size = 15048, upload-time = "2024-10-23T09:01:12.343Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/ce/fc699345a3cdfb4425b5dc1e446f1b49702ac55907a6a4d5d806f2512dae/starsessions-2.2.1-py3-none-any.whl", hash = "sha256:8097b33d70017b2d2331307f0ea923620b5bfb847118d2e5872805d0c1c16f83", size = 14621, upload-time = "2024-10-23T09:01:10.88Z" }, +] + +[package.optional-dependencies] +redis = [ + { name = "redis" }, +] + [[package]] name = "structlog" version = "25.4.0"