diff --git a/CHANGELOG.md b/CHANGELOG.md index c27ec532..223ab688 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Sidebar collapsed sections (initiatives, tags) no longer mount child DOM nodes — lazy-render on expand - Skip `useSortable` hooks when drag-and-drop is disabled (sorting/grouping active) for better scroll performance - Keep previous React Query data as placeholder for snappier page navigation +- Replaced `sort_by`/`sort_dir` string parameters on the tasks list endpoint with a structured `sorting` JSON parameter (`SortField[]`) — enables multi-column sorting (e.g. date group then due date) using the same pattern as `conditions` uses `FilterCondition[]` +- Frontend task tables (`useGlobalTasksTable`, `TagTasksTable`, dashboard, route loaders) now pass `SortField[]` arrays instead of individual sort strings ## [0.31.5] - 2026-02-20 diff --git a/backend/app/api/v1/endpoints/tasks.py b/backend/app/api/v1/endpoints/tasks.py index 92048aed..7d3694ce 100644 --- a/backend/app/api/v1/endpoints/tasks.py +++ b/backend/app/api/v1/endpoints/tasks.py @@ -6,6 +6,17 @@ from sqlalchemy import case, func, text from sqlmodel import select, delete +from app.db.query import ( + apply_filters, + apply_sorting, + build_paginated_response, + extract_condition_value, + paginated_query, + parse_conditions, + parse_sort_fields, +) +from app.schemas.query import FilterOp + from app.api.deps import ( RLSSessionDep, SessionDep, @@ -40,7 +51,7 @@ from app.services.recurrence import get_next_due_date from app.services import task_statuses as task_statuses_service from app.services import ai_generation as ai_generation_service -from app.core.messages import ProjectMessages, TaskMessages, SubtaskMessages +from app.core.messages import ProjectMessages, QueryMessages, TaskMessages, SubtaskMessages router = APIRouter() @@ -91,33 +102,88 @@ def _date_group_expression(): } -def 
_apply_task_sort(statement, sort_by: Optional[str], sort_dir: Optional[str]): - """Apply ORDER BY clause based on sort_by/sort_dir params, with fallback. +TASK_DEFAULT_SORT = [(Task.sort_order, "asc"), (Task.id, "asc")] + - Supports multi-sort via comma-separated values: - sort_by=date_group,due_date sort_dir=asc,asc +def _build_task_filter_fields(*, guild_id: int, current_user_id: int) -> dict: + """Build allowed_fields dict from Task model columns plus callable overrides. + + Every column on the Task table is automatically available as a filter + field (e.g. ``project_id``, ``priority``, ``due_date``, ``title``). + Virtual fields that require subqueries (``status_category``, + ``assignee_ids``, ``tag_ids``, ``initiative_ids``) are added as + callable handlers that receive ``(op, value)`` and return a SA clause. """ - if not sort_by: - return statement.order_by(Task.sort_order.asc(), Task.id.asc()) + # Auto-populate from model columns + fields: dict = { + col.name: getattr(Task, col.name) + for col in Task.__table__.columns + } + + # Callable overrides for virtual / cross-table fields + def _status_category_handler(op: FilterOp, value): + if not value: + return None + subq = select(TaskStatus.id).where( + TaskStatus.category.in_(tuple(value)) + ) + return Task.task_status_id.in_(subq) + + def _assignee_ids_handler(op: FilterOp, value): + if not value: + return None + user_ids = [] + for aid in value: + if aid == "me": + user_ids.append(current_user_id) + else: + try: + user_ids.append(int(aid)) + except (ValueError, TypeError): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=TaskMessages.INVALID_ASSIGNEE_ID, + ) + if not user_ids: + return None + subq = ( + select(TaskAssignee.task_id) + .where(TaskAssignee.user_id.in_(tuple(user_ids))) + ) + return Task.id.in_(subq) + + def _tag_ids_handler(op: FilterOp, value): + if not value: + return None + subq = ( + select(TaskTag.task_id) + .join(Tag, Tag.id == TaskTag.tag_id) + .where( + 
TaskTag.tag_id.in_(tuple(value)), + Tag.guild_id == guild_id, + ) + .distinct() + ) + return Task.id.in_(subq) + + def _initiative_ids_handler(op: FilterOp, value): + if not value: + return None + subq = select(Project.id).where( + Project.initiative_id.in_(tuple(value)) + ) + return Task.project_id.in_(subq) + + fields["status_category"] = _status_category_handler + fields["assignee_ids"] = _assignee_ids_handler + fields["tag_ids"] = _tag_ids_handler + fields["initiative_ids"] = _initiative_ids_handler + + return fields - fields = [f.strip() for f in sort_by.split(",") if f.strip()] - dirs = [d.strip() for d in (sort_dir or "").split(",")] - has_valid = False - for i, field_name in enumerate(fields): - col = TASK_SORT_FIELDS.get(field_name) - if col is None: - continue - direction = dirs[i] if i < len(dirs) else "asc" - order = col.desc() if direction == "desc" else col.asc() - statement = statement.order_by(order.nulls_last()) - has_valid = True - if not has_valid: - return statement.order_by(Task.sort_order.asc(), Task.id.asc()) - # tiebreaker - return statement.order_by(Task.id.asc()) subtasks_router = APIRouter() GuildContextDep = Annotated[GuildContext, Depends(get_guild_membership)] @@ -577,10 +643,8 @@ async def _list_global_tasks( include_archived: bool = False, page: int = 1, page_size: int = 20, - sort_by: Optional[str] = None, - sort_dir: Optional[str] = None, -) -> tuple[list[Task], int]: - # Base conditions shared by count and data queries + sort_fields: list | None = None, +) -> tuple[list[Task], int, int]: conditions = [ TaskAssignee.user_id == current_user.id, GuildMembership.user_id == current_user.id, @@ -614,12 +678,9 @@ def _base_query(stmt): ) return stmt.where(*conditions) - # Count query count_subq = _base_query(select(Task.id)).subquery() count_stmt = select(func.count()).select_from(count_subq) - total_count = (await session.exec(count_stmt)).one() - # Data query statement = _base_query(select(Task)).options( selectinload(Task.project) 
.selectinload(Project.initiative) @@ -628,13 +689,9 @@ def _base_query(stmt): selectinload(Task.task_status), selectinload(Task.tag_links).selectinload(TaskTag.tag), ) - statement = _apply_task_sort(statement, sort_by, sort_dir) - - if page_size > 0: - statement = statement.offset((page - 1) * page_size).limit(page_size) + statement = apply_sorting(statement, Task, sort_fields=sort_fields, allowed_fields=TASK_SORT_FIELDS, default_sort=TASK_DEFAULT_SORT) - result = await session.exec(statement) - return result.all(), total_count + return await paginated_query(session, statement, count_stmt, page, page_size) async def _list_global_created_tasks( @@ -649,11 +706,9 @@ async def _list_global_created_tasks( include_archived: bool = False, page: int = 1, page_size: int = 20, - sort_by: Optional[str] = None, - sort_dir: Optional[str] = None, -) -> tuple[list[Task], int]: + sort_fields: list | None = None, +) -> tuple[list[Task], int, int]: """List tasks created by the current user across all guilds they belong to.""" - # Base conditions shared by count and data queries conditions = [ Task.created_by_id == current_user.id, GuildMembership.user_id == current_user.id, @@ -671,7 +726,6 @@ async def _list_global_created_tasks( if guild_ids: conditions.append(Initiative.guild_id.in_(tuple(guild_ids))) - # Build base join chain (no TaskAssignee join needed) def _base_query(stmt): stmt = ( stmt @@ -686,12 +740,9 @@ def _base_query(stmt): ) return stmt.where(*conditions) - # Count query count_subq = _base_query(select(Task.id)).subquery() count_stmt = select(func.count()).select_from(count_subq) - total_count = (await session.exec(count_stmt)).one() - # Data query statement = _base_query(select(Task)).options( selectinload(Task.project) .selectinload(Project.initiative) @@ -700,13 +751,9 @@ def _base_query(stmt): selectinload(Task.task_status), selectinload(Task.tag_links).selectinload(TaskTag.tag), ) - statement = _apply_task_sort(statement, sort_by, sort_dir) + statement = 
apply_sorting(statement, Task, sort_fields=sort_fields, allowed_fields=TASK_SORT_FIELDS, default_sort=TASK_DEFAULT_SORT) - if page_size > 0: - statement = statement.offset((page - 1) * page_size).limit(page_size) - - result = await session.exec(statement) - return result.all(), total_count + return await paginated_query(session, statement, count_stmt, page, page_size) @router.get("/", response_model=TaskListResponse) @@ -714,23 +761,49 @@ async def list_tasks( session: RLSSessionDep, current_user: Annotated[User, Depends(get_current_active_user)], guild_context: GuildContextDep, - project_id: Optional[int] = Query(default=None), scope: Annotated[Literal["global", "global_created"] | None, Query()] = None, - assignee_ids: Optional[List[str]] = Query(default=None), - task_status_ids: Optional[List[int]] = Query(default=None), - priorities: Optional[List[TaskPriority]] = Query(default=None), - status_category: Optional[List[TaskStatusCategory]] = Query(default=None), - initiative_ids: Optional[List[int]] = Query(default=None), - guild_ids: Optional[List[int]] = Query(default=None), - tag_ids: Optional[List[int]] = Query(default=None, description="Filter by tag IDs"), + conditions: Optional[str] = Query( + default=None, + description=( + 'JSON list of filter conditions. Each object: ' + '{"field": "", "op": "", "value": }. ' + "Any Task column is valid plus virtual fields: " + "status_category, assignee_ids, tag_ids, initiative_ids." 
+ ), + ), include_archived: bool = Query(default=False, description="Include archived tasks"), page: int = Query(default=1, ge=1), page_size: int = Query(default=20, ge=0, le=100), - sort_by: Optional[str] = Query(default=None, description="Sort field(s), comma-separated: sort_order, title, due_date, start_date, priority, created_at, updated_at, date_group"), - sort_dir: Optional[str] = Query(default=None, description="Sort direction(s), comma-separated: asc or desc (one per sort field)"), + sorting: Optional[str] = Query( + default=None, + description='JSON list of sort fields: [{"field": "due_date", "dir": "desc"}]', + ), ) -> TaskListResponse: + try: + user_conditions = parse_conditions(conditions) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=QueryMessages.INVALID_CONDITIONS, + ) + + try: + sort_fields = parse_sort_fields(sorting) or None + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=QueryMessages.INVALID_SORT_FIELDS, + ) + + # Extract values needed by global paths and access control + project_id = extract_condition_value(user_conditions, "project_id") + priorities = extract_condition_value(user_conditions, "priority") + status_category = extract_condition_value(user_conditions, "status_category") + initiative_ids = extract_condition_value(user_conditions, "initiative_ids") + guild_ids = extract_condition_value(user_conditions, "guild_ids") + if scope == "global": - tasks, total_count = await _list_global_tasks( + tasks, total_count, actual_page = await _list_global_tasks( session, current_user, project_id=project_id, @@ -741,30 +814,18 @@ async def list_tasks( include_archived=include_archived, page=page, page_size=page_size, - sort_by=sort_by, - sort_dir=sort_dir, + sort_fields=sort_fields, ) await _annotate_tasks(session, tasks) _annotate_task_tags(tasks) items = [_task_to_list_read(task) for task in tasks] - if page_size > 0: - has_next = page * page_size < 
total_count - else: - # page_size=0 means "all rows, no pagination" - has_next = False - page = 1 - return TaskListResponse( - items=items, - total_count=total_count, - page=page, - page_size=page_size, - has_next=has_next, - sort_by=sort_by, - sort_dir=sort_dir, - ) + return TaskListResponse(**build_paginated_response( + items=items, total_count=total_count, page=actual_page, + page_size=page_size, sorting=sorting, + )) elif scope == "global_created": - tasks, total_count = await _list_global_created_tasks( + tasks, total_count, actual_page = await _list_global_created_tasks( session, current_user, project_id=project_id, @@ -775,47 +836,21 @@ async def list_tasks( include_archived=include_archived, page=page, page_size=page_size, - sort_by=sort_by, - sort_dir=sort_dir, + sort_fields=sort_fields, ) await _annotate_tasks(session, tasks) _annotate_task_tags(tasks) items = [_task_to_list_read(task) for task in tasks] - if page_size > 0: - has_next = page * page_size < total_count - else: - has_next = False - page = 1 - return TaskListResponse( - items=items, - total_count=total_count, - page=page, - page_size=page_size, - has_next=has_next, - sort_by=sort_by, - sort_dir=sort_dir, - ) + return TaskListResponse(**build_paginated_response( + items=items, total_count=total_count, page=actual_page, + page_size=page_size, sorting=sorting, + )) # Non-global (guild-scoped) path - conditions = [Initiative.guild_id == guild_context.guild_id] + access_conditions = [Initiative.guild_id == guild_context.guild_id] if not include_archived: - conditions.append(Task.is_archived.is_(False)) - - if project_id is not None: - conditions.append(Task.project_id == project_id) - - if task_status_ids: - conditions.append(Task.task_status_id.in_(tuple(task_status_ids))) - - if priorities: - conditions.append(Task.priority.in_(tuple(priorities))) - - if initiative_ids: - conditions.append(Project.initiative_id.in_(tuple(initiative_ids))) - - if guild_ids: - 
conditions.append(Initiative.guild_id.in_(tuple(guild_ids))) + access_conditions.append(Task.is_archived.is_(False)) allowed_ids = await _allowed_project_ids( session, @@ -825,48 +860,26 @@ async def list_tasks( ) if allowed_ids is not None: if not allowed_ids: - return TaskListResponse(items=[], total_count=0, page=page, page_size=page_size, has_next=False, sort_by=sort_by, sort_dir=sort_dir) - conditions.append(Task.project_id.in_(tuple(allowed_ids))) + return TaskListResponse(**build_paginated_response( + items=[], total_count=0, page=1, + page_size=page_size, sorting=sorting, + )) + access_conditions.append(Task.project_id.in_(tuple(allowed_ids))) + + filter_fields = _build_task_filter_fields( + guild_id=guild_context.guild_id, + current_user_id=current_user.id, + ) def _build_non_global_query(stmt): stmt = stmt.join(Task.project).join(Project.initiative) - if assignee_ids: - user_ids = [] - for assignee_id in assignee_ids: - if assignee_id == "me": - user_ids.append(current_user.id) - else: - try: - user_ids.append(int(assignee_id)) - except ValueError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=TaskMessages.INVALID_ASSIGNEE_ID) - if user_ids: - stmt = stmt.join(TaskAssignee, TaskAssignee.task_id == Task.id).where( - TaskAssignee.user_id.in_(tuple(user_ids)) - ) - if status_category: - stmt = stmt.join(TaskStatus, Task.task_status_id == TaskStatus.id).where( - TaskStatus.category.in_(tuple(status_category)) - ) - if tag_ids: - tag_subquery = ( - select(TaskTag.task_id) - .join(Tag, Tag.id == TaskTag.tag_id) - .where( - TaskTag.tag_id.in_(tuple(tag_ids)), - Tag.guild_id == guild_context.guild_id, - ) - .distinct() - ) - stmt = stmt.where(Task.id.in_(tag_subquery)) - return stmt.where(*conditions) + stmt = stmt.where(*access_conditions) + stmt = apply_filters(stmt, Task, user_conditions, allowed_fields=filter_fields) + return stmt - # Count query count_subq = _build_non_global_query(select(Task.id)).subquery() count_stmt = 
select(func.count()).select_from(count_subq) - total_count = (await session.exec(count_stmt)).one() - # Data query statement = _build_non_global_query(select(Task)).options( selectinload(Task.project) .selectinload(Project.initiative) @@ -875,31 +888,16 @@ def _build_non_global_query(stmt): selectinload(Task.task_status), selectinload(Task.tag_links).selectinload(TaskTag.tag), ) - statement = _apply_task_sort(statement, sort_by, sort_dir) - - if page_size > 0: - statement = statement.offset((page - 1) * page_size).limit(page_size) + statement = apply_sorting(statement, Task, sort_fields=sort_fields, allowed_fields=TASK_SORT_FIELDS, default_sort=TASK_DEFAULT_SORT) - result = await session.exec(statement) - tasks = result.all() + tasks, total_count, actual_page = await paginated_query(session, statement, count_stmt, page, page_size) await _annotate_tasks(session, tasks) _annotate_task_tags(tasks) items = [_task_to_list_read(task) for task in tasks] - if page_size > 0: - has_next = page * page_size < total_count - else: - # page_size=0 means "all rows, no pagination" - has_next = False - page = 1 - return TaskListResponse( - items=items, - total_count=total_count, - page=page, - page_size=page_size, - has_next=has_next, - sort_by=sort_by, - sort_dir=sort_dir, - ) + return TaskListResponse(**build_paginated_response( + items=items, total_count=total_count, page=actual_page, + page_size=page_size, sorting=sorting, + )) @router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED) diff --git a/backend/app/api/v1/endpoints/tasks_global_created_test.py b/backend/app/api/v1/endpoints/tasks_global_created_test.py index 67920616..96cbbae1 100644 --- a/backend/app/api/v1/endpoints/tasks_global_created_test.py +++ b/backend/app/api/v1/endpoints/tasks_global_created_test.py @@ -5,6 +5,8 @@ tasks created by the current user across all guilds they belong to. 
""" +import json + import pytest from httpx import AsyncClient from sqlmodel.ext.asyncio.session import AsyncSession @@ -181,8 +183,9 @@ async def test_list_global_created_tasks_priority_filter( await session.commit() headers = get_guild_headers(guild, user) + conditions = json.dumps([{"field": "priority", "op": "in_", "value": ["high"]}]) response = await client.get( - "/api/v1/tasks/?scope=global_created&priorities=high", headers=headers + f"/api/v1/tasks/?scope=global_created&conditions={conditions}", headers=headers ) assert response.status_code == 200 @@ -213,8 +216,9 @@ async def test_list_global_created_tasks_guild_filter( # Filter to guild1 only headers = get_guild_headers(guild1, user) + conditions = json.dumps([{"field": "guild_ids", "op": "in_", "value": [guild1.id]}]) response = await client.get( - f"/api/v1/tasks/?scope=global_created&guild_ids={guild1.id}", + f"/api/v1/tasks/?scope=global_created&conditions={conditions}", headers=headers, ) diff --git a/backend/app/api/v1/endpoints/tasks_test.py b/backend/app/api/v1/endpoints/tasks_test.py index ee32ce07..88e21b3e 100644 --- a/backend/app/api/v1/endpoints/tasks_test.py +++ b/backend/app/api/v1/endpoints/tasks_test.py @@ -12,6 +12,8 @@ - Task reordering """ +import json + import pytest from httpx import AsyncClient from sqlmodel.ext.asyncio.session import AsyncSession @@ -95,8 +97,9 @@ async def test_list_tasks_in_project(client: AsyncClient, session: AsyncSession) task2 = await _create_task(session, project, "Task 2") headers = get_guild_headers(guild, user) + conditions = json.dumps([{"field": "project_id", "op": "eq", "value": project.id}]) response = await client.get( - f"/api/v1/tasks/?project_id={project.id}", headers=headers + f"/api/v1/tasks/?conditions={conditions}", headers=headers ) assert response.status_code == 200 @@ -555,7 +558,8 @@ async def test_list_my_tasks(client: AsyncClient, session: AsyncSession): await session.commit() headers = get_guild_headers(guild, user) - response = await 
client.get("/api/v1/tasks/?assignee_ids=me", headers=headers) + conditions = json.dumps([{"field": "assignee_ids", "op": "in_", "value": ["me"]}]) + response = await client.get(f"/api/v1/tasks/?conditions={conditions}", headers=headers) assert response.status_code == 200 data = response.json()["items"] @@ -595,8 +599,12 @@ async def test_filter_tasks_by_status(client: AsyncClient, session: AsyncSession await session.commit() headers = get_guild_headers(guild, user) + conditions = json.dumps([ + {"field": "project_id", "op": "eq", "value": project.id}, + {"field": "task_status_id", "op": "in_", "value": [todo_status.id]}, + ]) response = await client.get( - f"/api/v1/tasks/?project_id={project.id}&task_status_ids={todo_status.id}", + f"/api/v1/tasks/?conditions={conditions}", headers=headers, ) @@ -661,8 +669,9 @@ async def test_rolling_recurrence_preserves_due_time( assert response.status_code == 200 # Fetch all tasks to find the newly created recurring task + conditions = json.dumps([{"field": "project_id", "op": "eq", "value": project.id}]) response = await client.get( - f"/api/v1/tasks/?project_id={project.id}", headers=headers + f"/api/v1/tasks/?conditions={conditions}", headers=headers ) assert response.status_code == 200 tasks = response.json()["items"] @@ -736,8 +745,9 @@ async def test_fixed_recurrence_uses_original_due_date( assert response.status_code == 200 # Fetch all tasks + conditions = json.dumps([{"field": "project_id", "op": "eq", "value": project.id}]) response = await client.get( - f"/api/v1/tasks/?project_id={project.id}", headers=headers + f"/api/v1/tasks/?conditions={conditions}", headers=headers ) assert response.status_code == 200 tasks = response.json()["items"] @@ -810,8 +820,9 @@ async def test_rolling_recurrence_with_midnight_time( assert response.status_code == 200 # Fetch all tasks + conditions = json.dumps([{"field": "project_id", "op": "eq", "value": project.id}]) response = await client.get( - 
f"/api/v1/tasks/?project_id={project.id}", headers=headers + f"/api/v1/tasks/?conditions={conditions}", headers=headers ) assert response.status_code == 200 tasks = response.json()["items"] diff --git a/backend/app/core/messages.py b/backend/app/core/messages.py index 09d5631c..e668cd67 100644 --- a/backend/app/core/messages.py +++ b/backend/app/core/messages.py @@ -234,5 +234,10 @@ class ImportMessages: INVALID_STATUS_ID = "IMPORT_INVALID_STATUS_ID" +class QueryMessages: + INVALID_CONDITIONS = "QUERY_INVALID_CONDITIONS" + INVALID_SORT_FIELDS = "QUERY_INVALID_SORT_FIELDS" + + class NotificationMessages: NOT_FOUND = "NOTIFICATION_NOT_FOUND" diff --git a/backend/app/db/query.py b/backend/app/db/query.py new file mode 100644 index 00000000..757d61a8 --- /dev/null +++ b/backend/app/db/query.py @@ -0,0 +1,384 @@ +"""Reusable query utilities for filtering, sorting, and pagination. + +Provides composable functions that transform SQLAlchemy Select statements: +- parse_conditions: safely parses a JSON string into FilterCondition list +- apply_filters: adds WHERE clauses from FilterCondition/FilterGroup lists +- apply_sorting: adds ORDER BY clauses from SortField list or comma-separated strings +- apply_pagination: adds OFFSET/LIMIT +- paginated_query: executes count + data queries, clamps page, returns (items, total, page) +""" + +from __future__ import annotations + +import json +from typing import Any + +from pydantic import ValidationError +from sqlalchemy import Select, and_, asc, desc, not_, or_ +from sqlmodel.ext.asyncio.session import AsyncSession + +from app.schemas.query import FilterCondition, FilterGroup, FilterOp, SortField, SortDir + + +# Hard limits to prevent abuse via oversized payloads. 
+_MAX_CONDITIONS = 50 +_MAX_SORT_FIELDS = 10 +_MAX_RAW_LENGTH = 10_000 + + +def parse_conditions( + raw: str | None, + *, + max_conditions: int = _MAX_CONDITIONS, + max_length: int = _MAX_RAW_LENGTH, +) -> list[FilterCondition]: + """Safely parse a JSON-encoded list of filter conditions. + + Designed for use with query parameters that carry structured filters as a + JSON string. Applies size and count limits before touching the payload so + an attacker cannot exhaust memory or CPU with a crafted input. + + Returns an empty list when *raw* is ``None`` or empty. + + Raises :class:`ValueError` on any validation failure — callers should catch + this and convert to an appropriate HTTP error. + """ + if not raw: + return [] + + if len(raw) > max_length: + raise ValueError("conditions payload exceeds size limit") + + try: + items = json.loads(raw) + except (json.JSONDecodeError, ValueError) as exc: + raise ValueError("conditions is not valid JSON") from exc + + if not isinstance(items, list): + raise ValueError("conditions must be a JSON array") + + if len(items) > max_conditions: + raise ValueError(f"too many conditions (max {max_conditions})") + + try: + return [FilterCondition(**item) for item in items] + except (ValidationError, TypeError) as exc: + raise ValueError("invalid condition structure") from exc + + +def parse_sort_fields( + raw: str | None, + *, + max_fields: int = _MAX_SORT_FIELDS, + max_length: int = _MAX_RAW_LENGTH, +) -> list[SortField]: + """Safely parse a JSON-encoded list of sort fields. + + Mirrors :func:`parse_conditions` with the same security hardening. + Returns an empty list when *raw* is ``None`` or empty. + + Raises :class:`ValueError` on any validation failure. 
+ """ + if not raw: + return [] + + if len(raw) > max_length: + raise ValueError("sort fields payload exceeds size limit") + + try: + items = json.loads(raw) + except (json.JSONDecodeError, ValueError) as exc: + raise ValueError("sorting is not valid JSON") from exc + + if not isinstance(items, list): + raise ValueError("sorting must be a JSON array") + + if len(items) > max_fields: + raise ValueError(f"too many sort fields (max {max_fields})") + + try: + return [SortField(**item) for item in items] + except (ValidationError, TypeError) as exc: + raise ValueError("invalid sort field structure") from exc + + +def extract_condition_value( + conditions: list[FilterCondition], + field: str, +) -> Any: + """Return the ``value`` for the first condition matching *field*, or ``None``.""" + for cond in conditions: + if cond.field == field: + return cond.value + return None + + +def apply_filters( + statement: Select, + model: Any, + conditions: list[FilterCondition | FilterGroup], + allowed_fields: dict[str, Any] | None = None, +) -> Select: + """Apply filter conditions to a Select statement. + + ``conditions`` can contain flat :class:`FilterCondition` items (implicitly + AND-ed) or :class:`FilterGroup` items for explicit AND/OR logic. + + Both :class:`FilterCondition` and :class:`FilterGroup` support a ``negate`` + flag that wraps the resulting clause in ``NOT(...)``. + + ``allowed_fields`` maps field names to SQLAlchemy column expressions **or + callables**. A callable value receives ``(op, value)`` and must return a + SA clause element (or *None* to skip). Negation is still handled + uniformly by ``_resolve_condition``. + + If *None*, uses ``getattr(model, field)`` directly. + Unknown fields are silently skipped (defense in depth). 
+ """ + for cond in conditions: + clause = _resolve_condition(cond, model, allowed_fields) + if clause is not None: + statement = statement.where(clause) + + return statement + + +def _resolve_condition( + cond: FilterCondition | FilterGroup, + model: Any, + allowed_fields: dict[str, Any] | None, +): + """Recursively resolve a condition or group into a SA clause element.""" + if isinstance(cond, FilterGroup): + return _resolve_group(cond, model, allowed_fields) + + # Leaf FilterCondition + if allowed_fields is not None: + col_or_handler = allowed_fields.get(cond.field) + else: + col_or_handler = getattr(model, cond.field, None) + + if col_or_handler is None: + return None + + if callable(col_or_handler): + clause = col_or_handler(cond.op, cond.value) + else: + clause = _build_filter_clause(col_or_handler, cond.op, cond.value) + + if clause is None: + return None + + return not_(clause) if cond.negate else clause + + +def _resolve_group( + group: FilterGroup, + model: Any, + allowed_fields: dict[str, Any] | None, +): + """Resolve a FilterGroup into a SA and_()/or_() expression, optionally negated.""" + clauses = [] + for cond in group.conditions: + clause = _resolve_condition(cond, model, allowed_fields) + if clause is not None: + clauses.append(clause) + + if not clauses: + return None + + if len(clauses) == 1: + combined = clauses[0] + elif group.logic == "or": + combined = or_(*clauses) + else: + combined = and_(*clauses) + + return not_(combined) if group.negate else combined + + +def _build_filter_clause(col: Any, op: FilterOp, value: Any): + """Return a single WHERE clause for *col* with the given operator. + + Negation is handled by the caller via ``FilterCondition.negate``, + not by separate operators. 
+ """ + if op == FilterOp.eq: + return col == value + if op == FilterOp.lt: + return col < value + if op == FilterOp.lte: + return col <= value + if op == FilterOp.gt: + return col > value + if op == FilterOp.gte: + return col >= value + if op == FilterOp.in_: + if not value: + return None + return col.in_(tuple(value) if not isinstance(value, tuple) else value) + if op == FilterOp.ilike: + return col.ilike(f"%{value}%") + if op == FilterOp.is_null: + return col.is_(None) if value else col.is_not(None) + return None + + +def apply_sorting( + statement: Select, + model: Any, + sort_fields: list[SortField] | None = None, + allowed_fields: dict[str, Any] | None = None, + default_sort: list[tuple[Any, str]] | None = None, + *, + sort_by: str | None = None, + sort_dir: str | None = None, +) -> Select: + """Apply ORDER BY clauses to a Select statement. + + Accepts either structured ``sort_fields`` or comma-separated ``sort_by``/``sort_dir`` + strings (the current endpoint convention). When both are provided, ``sort_fields`` wins. + + ``allowed_fields`` maps field names to SA column expressions. If *None*, uses + ``getattr(model, field)`` directly. + + ``default_sort`` is applied when no valid sort fields are found. 
+ """ + fields_to_apply = _resolve_sort_fields(sort_fields, sort_by, sort_dir) + + has_valid = False + for sf in fields_to_apply: + if allowed_fields is not None: + col = allowed_fields.get(sf.field) + else: + col = getattr(model, sf.field, None) + + if col is None: + continue + + order = desc(col) if sf.dir == SortDir.desc else asc(col) + statement = statement.order_by(order.nulls_last()) + has_valid = True + + if not has_valid and default_sort: + for col, direction in default_sort: + order = desc(col) if direction == "desc" else asc(col) + statement = statement.order_by(order) + return statement + + if has_valid: + # Add model PK as tiebreaker if model has an 'id' attribute + pk = getattr(model, "id", None) + if pk is not None: + statement = statement.order_by(asc(pk)) + + return statement + + +def _resolve_sort_fields( + sort_fields: list[SortField] | None, + sort_by: str | None, + sort_dir: str | None, +) -> list[SortField]: + """Convert comma-separated sort_by/sort_dir into SortField list, or use sort_fields.""" + if sort_fields: + return sort_fields + + if not sort_by: + return [] + + fields = [f.strip() for f in sort_by.split(",") if f.strip()] + dirs = [d.strip() for d in (sort_dir or "").split(",")] + + result = [] + for i, field_name in enumerate(fields): + direction = dirs[i] if i < len(dirs) else "asc" + try: + dir_enum = SortDir(direction) + except ValueError: + dir_enum = SortDir.asc + result.append(SortField(field=field_name, dir=dir_enum)) + + return result + + +def apply_pagination( + statement: Select, + page: int = 1, + page_size: int = 20, +) -> Select: + """Apply OFFSET/LIMIT. ``page_size=0`` means no pagination (all rows).""" + if page_size <= 0: + return statement + return statement.offset((page - 1) * page_size).limit(page_size) + + +def _clamp_page(page: int, page_size: int, total_count: int) -> int: + """Reset page to 1 if it overshoots the available results. 
+ + This handles the case where filters/sort changed and the current page + no longer exists (e.g. user was on page 5, but new filters only yield 2 pages). + """ + if page_size <= 0: + return 1 + if total_count == 0: + return 1 + import math + total_pages = math.ceil(total_count / page_size) + if page > total_pages: + return 1 + return page + + +async def paginated_query( + session: AsyncSession, + data_stmt: Select, + count_stmt: Select, + page: int = 1, + page_size: int = 20, +) -> tuple[list, int, int]: + """Execute count + data queries with automatic page clamping. + + Returns ``(items, total_count, actual_page)`` where *actual_page* may + differ from the requested *page* if it overshot the result set. + """ + total_count = (await session.exec(count_stmt)).one() + + page = _clamp_page(page, page_size, total_count) + + data_stmt = apply_pagination(data_stmt, page, page_size) + result = await session.exec(data_stmt) + items = list(result.all()) + + return items, total_count, page + + +def build_paginated_response( + items: list, + total_count: int, + page: int, + page_size: int, + **extra: Any, +) -> dict: + """Build a dict suitable for unpacking into a concrete response model. + + Computes ``has_next`` and ``has_prev`` automatically from the inputs. 
+ """ + if page_size <= 0: + effective_page = 1 + has_next = False + has_prev = False + else: + effective_page = page + has_next = page * page_size < total_count + has_prev = page > 1 + + return { + "items": items, + "total_count": total_count, + "page": effective_page, + "page_size": page_size, + "has_next": has_next, + "has_prev": has_prev, + **extra, + } diff --git a/backend/app/db/query_test.py b/backend/app/db/query_test.py new file mode 100644 index 00000000..3cf80061 --- /dev/null +++ b/backend/app/db/query_test.py @@ -0,0 +1,959 @@ +"""Unit tests for the reusable query utility functions.""" + +import pytest +from sqlalchemy import Column, Integer, MetaData, String, Boolean, Float, Table +from sqlmodel import select + +from app.db.query import ( + _clamp_page, + apply_filters, + apply_sorting, + apply_pagination, + build_paginated_response, + extract_condition_value, + parse_conditions, + parse_sort_fields, +) +from app.schemas.query import FilterCondition, FilterGroup, FilterOp, SortField, SortDir + + +# --------------------------------------------------------------------------- +# Dummy table for testing (not persisted — we only inspect generated SQL) +# --------------------------------------------------------------------------- + +_test_metadata = MetaData() +_dummy_table = Table( + "dummy_items", + _test_metadata, + Column("id", Integer, primary_key=True), + Column("name", String(100)), + Column("priority", String(20)), + Column("score", Float), + Column("is_active", Boolean), +) + + +class _DummyModel: + """Attribute-access wrapper around the dummy table columns.""" + id = _dummy_table.c.id + name = _dummy_table.c.name + priority = _dummy_table.c.priority + score = _dummy_table.c.score + is_active = _dummy_table.c.is_active + + +# --------------------------------------------------------------------------- +# apply_filters +# --------------------------------------------------------------------------- + +class TestApplyFilters: + """Tests for 
apply_filters.""" + + def test_eq_filter(self): + stmt = select(_dummy_table) + conditions = [FilterCondition(field="name", op=FilterOp.eq, value="alice")] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "name = 'alice'" in sql + + def test_negate_eq(self): + """negate=True on eq negates the comparison.""" + stmt = select(_dummy_table) + conditions = [FilterCondition(field="name", op=FilterOp.eq, value="bob", negate=True)] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + # SA optimizes NOT(x = y) into x != y + assert "name != 'bob'" in sql + + def test_negate_in(self): + """negate=True on in negates the IN clause.""" + stmt = select(_dummy_table) + conditions = [FilterCondition(field="priority", op=FilterOp.in_, value=["low", "medium"], negate=True)] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})).upper() + # SA may render as NOT IN or NOT (... IN ...) 
+ assert "NOT" in sql or "NOT IN" in sql + assert "'LOW'" in sql + + def test_negate_gt(self): + """negate=True on gt negates to <= (SA optimizes).""" + stmt = select(_dummy_table) + conditions = [FilterCondition(field="score", op=FilterOp.gt, value=5, negate=True)] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + # SA optimizes NOT(score > 5) into score <= 5 + assert "score <= 5" in sql + + def test_negate_false_is_normal(self): + """negate=False (default) behaves like a normal filter.""" + stmt = select(_dummy_table) + conditions = [FilterCondition(field="name", op=FilterOp.eq, value="alice", negate=False)] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "name = 'alice'" in sql + + def test_lt_lte_gt_gte_filters(self): + stmt = select(_dummy_table) + conditions = [ + FilterCondition(field="score", op=FilterOp.gt, value=5), + FilterCondition(field="score", op=FilterOp.lte, value=100), + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "score > 5" in sql + assert "score <= 100" in sql + + def test_in_filter(self): + stmt = select(_dummy_table) + conditions = [FilterCondition(field="priority", op=FilterOp.in_, value=["high", "urgent"])] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "IN" in sql + assert "'high'" in sql + assert "'urgent'" in sql + + def test_in_filter_empty_list_skipped(self): + """An in_ filter with an empty list should produce no WHERE clause.""" + stmt = select(_dummy_table) + conditions = [FilterCondition(field="priority", op=FilterOp.in_, value=[])] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "WHERE" not in sql + + def 
test_ilike_filter(self): + stmt = select(_dummy_table) + conditions = [FilterCondition(field="name", op=FilterOp.ilike, value="test")] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})).lower() + assert "ilike" in sql or "like" in sql + + def test_is_null_true(self): + stmt = select(_dummy_table) + conditions = [FilterCondition(field="name", op=FilterOp.is_null, value=True)] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "IS NULL" in sql.upper() + + def test_is_null_false(self): + stmt = select(_dummy_table) + conditions = [FilterCondition(field="name", op=FilterOp.is_null, value=False)] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "IS NOT NULL" in sql.upper() + + def test_unknown_field_skipped(self): + """Fields not in allowed_fields should be silently ignored.""" + stmt = select(_dummy_table) + conditions = [FilterCondition(field="nonexistent", op=FilterOp.eq, value="x")] + allowed = {"name": _DummyModel.name} + result = apply_filters(stmt, _DummyModel, conditions, allowed_fields=allowed) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "WHERE" not in sql + + def test_allowed_fields_whitelist(self): + """Only fields in the whitelist should be applied.""" + stmt = select(_dummy_table) + conditions = [ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + FilterCondition(field="priority", op=FilterOp.eq, value="high"), + ] + allowed = {"name": _DummyModel.name} # priority not allowed + result = apply_filters(stmt, _DummyModel, conditions, allowed_fields=allowed) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "name = 'alice'" in sql + # priority should appear in SELECT but not in WHERE + where_clause = sql.split("WHERE", 1)[1] if "WHERE" in sql else "" + 
assert "priority" not in where_clause + + def test_multiple_conditions(self): + stmt = select(_dummy_table) + conditions = [ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + FilterCondition(field="score", op=FilterOp.gte, value=10), + FilterCondition(field="is_active", op=FilterOp.eq, value=True), + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "name" in sql + assert "score" in sql + assert "is_active" in sql + + def test_no_conditions_returns_original(self): + stmt = select(_dummy_table) + result = apply_filters(stmt, _DummyModel, []) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "WHERE" not in sql + + +# --------------------------------------------------------------------------- +# FilterGroup (and / or) +# --------------------------------------------------------------------------- + +class TestFilterGroup: + """Tests for AND/OR grouping via FilterGroup.""" + + def test_or_same_field(self): + """OR two values for the same field: name = 'alice' OR name = 'bob'.""" + stmt = select(_dummy_table) + conditions = [ + FilterGroup( + logic="or", + conditions=[ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + FilterCondition(field="name", op=FilterOp.eq, value="bob"), + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "OR" in sql + assert "'alice'" in sql + assert "'bob'" in sql + + def test_and_group(self): + """Explicit AND group: name = 'alice' AND score > 5.""" + stmt = select(_dummy_table) + conditions = [ + FilterGroup( + logic="and", + conditions=[ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + FilterCondition(field="score", op=FilterOp.gt, value=5), + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "name = 
'alice'" in sql + assert "score > 5" in sql + + def test_or_with_allowed_fields(self): + """OR group respects allowed_fields whitelist.""" + stmt = select(_dummy_table) + allowed = {"name": _DummyModel.name} + conditions = [ + FilterGroup( + logic="or", + conditions=[ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + FilterCondition(field="score", op=FilterOp.gt, value=5), # not allowed + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions, allowed_fields=allowed) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + # Only name should be in the WHERE clause, score is filtered out + assert "'alice'" in sql + where_clause = sql.split("WHERE", 1)[1] if "WHERE" in sql else "" + assert "score" not in where_clause + + def test_nested_groups(self): + """Nested: is_active = true AND (name = 'alice' OR name = 'bob').""" + stmt = select(_dummy_table) + conditions = [ + FilterCondition(field="is_active", op=FilterOp.eq, value=True), + FilterGroup( + logic="or", + conditions=[ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + FilterCondition(field="name", op=FilterOp.eq, value="bob"), + ], + ), + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "is_active" in sql + assert "OR" in sql + assert "'alice'" in sql + assert "'bob'" in sql + + def test_or_three_values(self): + """OR across three values for the same field.""" + stmt = select(_dummy_table) + conditions = [ + FilterGroup( + logic="or", + conditions=[ + FilterCondition(field="priority", op=FilterOp.eq, value="low"), + FilterCondition(field="priority", op=FilterOp.eq, value="medium"), + FilterCondition(field="priority", op=FilterOp.eq, value="high"), + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "OR" in sql + assert "'low'" in sql + assert "'medium'" in sql + assert "'high'" in 
sql + + def test_or_different_ops(self): + """OR with different operators: score > 90 OR score IS NULL.""" + stmt = select(_dummy_table) + conditions = [ + FilterGroup( + logic="or", + conditions=[ + FilterCondition(field="score", op=FilterOp.gt, value=90), + FilterCondition(field="score", op=FilterOp.is_null, value=True), + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})).upper() + assert "OR" in sql + assert "SCORE > 90" in sql + assert "IS NULL" in sql + + def test_empty_group_skipped(self): + """A group with no valid conditions should not add a WHERE clause.""" + stmt = select(_dummy_table) + allowed = {"name": _DummyModel.name} + conditions = [ + FilterGroup( + logic="or", + conditions=[ + FilterCondition(field="unknown1", op=FilterOp.eq, value="x"), + FilterCondition(field="unknown2", op=FilterOp.eq, value="y"), + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions, allowed_fields=allowed) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "WHERE" not in sql + + def test_single_condition_group_unwrapped(self): + """A group with one valid condition should not wrap in AND/OR.""" + stmt = select(_dummy_table) + conditions = [ + FilterGroup( + logic="or", + conditions=[ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "name = 'alice'" in sql + assert "OR" not in sql + + def test_negate_or_group(self): + """NOT (status = 'archived' OR status = 'deleted').""" + stmt = select(_dummy_table) + conditions = [ + FilterGroup( + logic="or", + negate=True, + conditions=[ + FilterCondition(field="name", op=FilterOp.eq, value="archived"), + FilterCondition(field="name", op=FilterOp.eq, value="deleted"), + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = 
str(result.compile(compile_kwargs={"literal_binds": True})).upper() + assert "NOT" in sql + assert "OR" in sql + assert "'ARCHIVED'" in sql + assert "'DELETED'" in sql + + def test_negate_and_group(self): + """NOT (name = 'alice' AND score > 5) — negate an AND group.""" + stmt = select(_dummy_table) + conditions = [ + FilterGroup( + logic="and", + negate=True, + conditions=[ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + FilterCondition(field="score", op=FilterOp.gt, value=5), + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})).upper() + assert "NOT" in sql + assert "'ALICE'" in sql + assert "SCORE > 5" in sql + + def test_negate_single_condition_group(self): + """NOT (name = 'alice') via a negated group with one condition.""" + stmt = select(_dummy_table) + conditions = [ + FilterGroup( + logic="or", + negate=True, + conditions=[ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + # SA optimizes NOT(name = 'alice') into name != 'alice' + assert "name != 'alice'" in sql + + def test_negate_false_group_is_normal(self): + """negate=False (default) on a group behaves normally.""" + stmt = select(_dummy_table) + conditions = [ + FilterGroup( + logic="or", + negate=False, + conditions=[ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + FilterCondition(field="name", op=FilterOp.eq, value="bob"), + ], + ) + ] + result = apply_filters(stmt, _DummyModel, conditions) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "OR" in sql + assert "NOT" not in sql.upper() + + +# --------------------------------------------------------------------------- +# Callable filter handlers +# --------------------------------------------------------------------------- + +class TestCallableFilterHandler: 
+ """Tests for callable handler support in allowed_fields.""" + + def test_callable_handler_basic(self): + """A callable handler returning an IN clause is applied.""" + def status_handler(op, value): + return _dummy_table.c.priority.in_(tuple(value)) + + stmt = select(_dummy_table) + conditions = [FilterCondition(field="status_category", op=FilterOp.in_, value=["active", "done"])] + allowed = {"status_category": status_handler} + result = apply_filters(stmt, _DummyModel, conditions, allowed_fields=allowed) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "IN" in sql + assert "'active'" in sql + assert "'done'" in sql + + def test_callable_handler_negate(self): + """negate=True wraps the handler result in NOT.""" + def handler(op, value): + return _dummy_table.c.name == value + + stmt = select(_dummy_table) + conditions = [FilterCondition(field="custom", op=FilterOp.eq, value="alice", negate=True)] + allowed = {"custom": handler} + result = apply_filters(stmt, _DummyModel, conditions, allowed_fields=allowed) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "name != 'alice'" in sql + + def test_callable_handler_returns_none_skipped(self): + """When handler returns None, no WHERE clause is added.""" + def handler(op, value): + return None + + stmt = select(_dummy_table) + conditions = [FilterCondition(field="custom", op=FilterOp.eq, value="x")] + allowed = {"custom": handler} + result = apply_filters(stmt, _DummyModel, conditions, allowed_fields=allowed) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "WHERE" not in sql + + def test_callable_handler_in_or_group(self): + """Callable inside a FilterGroup with OR logic.""" + def handler(op, value): + return _dummy_table.c.score > value + + stmt = select(_dummy_table) + conditions = [ + FilterGroup( + logic="or", + conditions=[ + FilterCondition(field="name", op=FilterOp.eq, value="alice"), + FilterCondition(field="high_score", 
op=FilterOp.gt, value=90), + ], + ) + ] + allowed = {"name": _DummyModel.name, "high_score": handler} + result = apply_filters(stmt, _DummyModel, conditions, allowed_fields=allowed) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "OR" in sql + assert "'alice'" in sql + assert "score > 90" in sql + + def test_callable_mixed_with_column(self): + """Dict with both column refs and callables works together.""" + def handler(op, value): + return _dummy_table.c.score >= value + + stmt = select(_dummy_table) + conditions = [ + FilterCondition(field="name", op=FilterOp.eq, value="bob"), + FilterCondition(field="min_score", op=FilterOp.gte, value=50), + ] + allowed = {"name": _DummyModel.name, "min_score": handler} + result = apply_filters(stmt, _DummyModel, conditions, allowed_fields=allowed) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "name = 'bob'" in sql + assert "score >= 50" in sql + + def test_callable_receives_correct_args(self): + """Handler receives the exact op and value from the FilterCondition.""" + received = {} + + def handler(op, value): + received["op"] = op + received["value"] = value + return _dummy_table.c.id == 1 # dummy clause + + stmt = select(_dummy_table) + conditions = [FilterCondition(field="custom", op=FilterOp.in_, value=[10, 20])] + allowed = {"custom": handler} + apply_filters(stmt, _DummyModel, conditions, allowed_fields=allowed) + assert received["op"] == FilterOp.in_ + assert received["value"] == [10, 20] + + +# --------------------------------------------------------------------------- +# apply_sorting +# --------------------------------------------------------------------------- + +class TestApplySorting: + """Tests for apply_sorting.""" + + def test_sort_by_string_asc(self): + stmt = select(_dummy_table) + result = apply_sorting(stmt, _DummyModel, sort_by="name", sort_dir="asc") + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "ORDER BY" in sql + 
assert "name" in sql + + def test_sort_by_string_desc(self): + stmt = select(_dummy_table) + result = apply_sorting(stmt, _DummyModel, sort_by="score", sort_dir="desc") + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "ORDER BY" in sql + assert "DESC" in sql + + def test_multi_sort_by_string(self): + stmt = select(_dummy_table) + result = apply_sorting(stmt, _DummyModel, sort_by="name,score", sort_dir="asc,desc") + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "ORDER BY" in sql + assert "name" in sql + assert "score" in sql + + def test_sort_fields_structured(self): + stmt = select(_dummy_table) + fields = [ + SortField(field="name", dir=SortDir.asc), + SortField(field="score", dir=SortDir.desc), + ] + result = apply_sorting(stmt, _DummyModel, sort_fields=fields) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "ORDER BY" in sql + assert "name" in sql + assert "score" in sql + + def test_default_sort_used_when_no_fields(self): + stmt = select(_dummy_table) + default = [(_DummyModel.score, "desc"), (_DummyModel.id, "asc")] + result = apply_sorting(stmt, _DummyModel, default_sort=default) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "ORDER BY" in sql + assert "score" in sql + + def test_default_sort_not_used_when_fields_provided(self): + stmt = select(_dummy_table) + default = [(_DummyModel.score, "desc")] + result = apply_sorting( + stmt, _DummyModel, + sort_by="name", sort_dir="asc", + default_sort=default, + ) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "name" in sql + + def test_unknown_sort_field_falls_back_to_default(self): + stmt = select(_dummy_table) + allowed = {"name": _DummyModel.name} + default = [(_DummyModel.id, "asc")] + result = apply_sorting( + stmt, _DummyModel, + sort_by="nonexistent", sort_dir="asc", + allowed_fields=allowed, + default_sort=default, + ) + sql = 
str(result.compile(compile_kwargs={"literal_binds": True})) + assert "ORDER BY" in sql + + def test_id_tiebreaker_added(self): + stmt = select(_dummy_table) + result = apply_sorting(stmt, _DummyModel, sort_by="name", sort_dir="asc") + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + # Should have id as tiebreaker at the end + assert "id" in sql + + def test_no_sort_no_default_returns_unmodified(self): + stmt = select(_dummy_table) + result = apply_sorting(stmt, _DummyModel) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "ORDER BY" not in sql + + +# --------------------------------------------------------------------------- +# apply_pagination +# --------------------------------------------------------------------------- + +class TestApplyPagination: + """Tests for apply_pagination.""" + + def test_first_page(self): + stmt = select(_dummy_table) + result = apply_pagination(stmt, page=1, page_size=20) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "LIMIT" in sql + assert "OFFSET" in sql + + def test_second_page(self): + stmt = select(_dummy_table) + result = apply_pagination(stmt, page=2, page_size=10) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "LIMIT" in sql + assert "10" in sql + + def test_page_size_zero_no_pagination(self): + stmt = select(_dummy_table) + result = apply_pagination(stmt, page=1, page_size=0) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "LIMIT" not in sql + assert "OFFSET" not in sql + + def test_negative_page_size_no_pagination(self): + stmt = select(_dummy_table) + result = apply_pagination(stmt, page=1, page_size=-1) + sql = str(result.compile(compile_kwargs={"literal_binds": True})) + assert "LIMIT" not in sql + + +# --------------------------------------------------------------------------- +# build_paginated_response +# --------------------------------------------------------------------------- + 
+class TestBuildPaginatedResponse: + """Tests for build_paginated_response helper.""" + + def test_basic_response(self): + result = build_paginated_response( + items=["a", "b", "c"], + total_count=50, + page=2, + page_size=20, + ) + assert result["items"] == ["a", "b", "c"] + assert result["total_count"] == 50 + assert result["page"] == 2 + assert result["page_size"] == 20 + assert result["has_next"] is True + assert result["has_prev"] is True + + def test_first_page_no_prev(self): + result = build_paginated_response( + items=["a", "b"], + total_count=5, + page=1, + page_size=3, + ) + assert result["has_next"] is True + assert result["has_prev"] is False + + def test_last_page_no_next(self): + result = build_paginated_response( + items=["e"], + total_count=5, + page=2, + page_size=4, + ) + assert result["has_next"] is False + assert result["has_prev"] is True + + def test_single_page(self): + result = build_paginated_response( + items=["a", "b"], + total_count=2, + page=1, + page_size=20, + ) + assert result["has_next"] is False + assert result["has_prev"] is False + + def test_extra_fields(self): + result = build_paginated_response( + items=[], + total_count=0, + page=1, + page_size=20, + sort_by="name", + sort_dir="asc", + ) + assert result["sort_by"] == "name" + assert result["sort_dir"] == "asc" + assert result["has_next"] is False + assert result["has_prev"] is False + + def test_page_size_zero_resets_page(self): + result = build_paginated_response( + items=["a"], + total_count=1, + page=5, + page_size=0, + ) + assert result["page"] == 1 + assert result["has_next"] is False + assert result["has_prev"] is False + + +# --------------------------------------------------------------------------- +# _clamp_page +# --------------------------------------------------------------------------- + +class TestClampPage: + """Tests for page clamping when page overshoots results.""" + + def test_valid_page_unchanged(self): + assert _clamp_page(2, 10, 50) == 2 + + def 
test_page_beyond_total_resets_to_1(self): + # 50 items, 20 per page = 3 pages. Page 5 is out of range. + assert _clamp_page(5, 20, 50) == 1 + + def test_last_page_is_valid(self): + # 50 items, 20 per page = 3 pages. Page 3 is valid. + assert _clamp_page(3, 20, 50) == 3 + + def test_zero_total_resets_to_1(self): + assert _clamp_page(3, 20, 0) == 1 + + def test_page_size_zero_returns_1(self): + assert _clamp_page(5, 0, 100) == 1 + + def test_page_1_always_valid(self): + assert _clamp_page(1, 20, 1) == 1 + + def test_exact_boundary(self): + # 20 items, 20 per page = 1 page. Page 1 is valid, page 2 is not. + assert _clamp_page(1, 20, 20) == 1 + assert _clamp_page(2, 20, 20) == 1 + + +# --------------------------------------------------------------------------- +# Schema validation +# --------------------------------------------------------------------------- + +class TestSchemas: + """Tests for query schema validation.""" + + def test_pagination_params_defaults(self): + from app.schemas.query import PaginationParams + params = PaginationParams() + assert params.page == 1 + assert params.page_size == 20 + + def test_pagination_params_validation(self): + from app.schemas.query import PaginationParams + with pytest.raises(Exception): + PaginationParams(page=0) # ge=1 + + def test_pagination_params_max_page_size(self): + from app.schemas.query import PaginationParams + with pytest.raises(Exception): + PaginationParams(page_size=101) # le=100 + + def test_filter_condition_defaults(self): + cond = FilterCondition(field="name", value="test") + assert cond.op == FilterOp.eq + assert cond.negate is False + + def test_sort_field_defaults(self): + sf = SortField(field="name") + assert sf.dir == SortDir.asc + + +# --------------------------------------------------------------------------- +# parse_conditions (security & validation) +# --------------------------------------------------------------------------- + +class TestParseConditions: + """Tests for parse_conditions including 
security hardening.""" + + def test_none_returns_empty(self): + assert parse_conditions(None) == [] + + def test_empty_string_returns_empty(self): + assert parse_conditions("") == [] + + def test_valid_single_condition(self): + raw = '[{"field": "name", "op": "eq", "value": "alice"}]' + result = parse_conditions(raw) + assert len(result) == 1 + assert result[0].field == "name" + assert result[0].op == FilterOp.eq + assert result[0].value == "alice" + + def test_valid_multiple_conditions(self): + raw = '[{"field": "a", "op": "eq", "value": 1}, {"field": "b", "op": "gt", "value": 5}]' + result = parse_conditions(raw) + assert len(result) == 2 + + def test_defaults_applied(self): + """op defaults to eq, negate defaults to False.""" + raw = '[{"field": "name", "value": "test"}]' + result = parse_conditions(raw) + assert result[0].op == FilterOp.eq + assert result[0].negate is False + + def test_negate_preserved(self): + raw = '[{"field": "name", "op": "eq", "value": "x", "negate": true}]' + result = parse_conditions(raw) + assert result[0].negate is True + + def test_rejects_oversized_payload(self): + with pytest.raises(ValueError, match="size limit"): + parse_conditions("x" * 10_001) + + def test_custom_max_length(self): + with pytest.raises(ValueError, match="size limit"): + parse_conditions("x" * 101, max_length=100) + + def test_rejects_invalid_json(self): + with pytest.raises(ValueError, match="not valid JSON"): + parse_conditions("{not json}") + + def test_rejects_non_array(self): + with pytest.raises(ValueError, match="must be a JSON array"): + parse_conditions('{"field": "name"}') + + def test_rejects_too_many_conditions(self): + items = [{"field": "f", "value": i} for i in range(51)] + import json + with pytest.raises(ValueError, match="too many conditions"): + parse_conditions(json.dumps(items)) + + def test_custom_max_conditions(self): + items = [{"field": "f", "value": i} for i in range(3)] + import json + with pytest.raises(ValueError, match="too many 
conditions"): + parse_conditions(json.dumps(items), max_conditions=2) + + def test_rejects_invalid_structure(self): + with pytest.raises(ValueError, match="invalid condition structure"): + parse_conditions('[{"bad_key": "value"}]') + + def test_rejects_invalid_op(self): + with pytest.raises(ValueError, match="invalid condition structure"): + parse_conditions('[{"field": "name", "op": "DROP TABLE", "value": "x"}]') + + def test_at_exact_limit_succeeds(self): + items = [{"field": "f", "value": i} for i in range(50)] + import json + result = parse_conditions(json.dumps(items)) + assert len(result) == 50 + + +# --------------------------------------------------------------------------- +# extract_condition_value +# --------------------------------------------------------------------------- + +class TestExtractConditionValue: + """Tests for extract_condition_value helper.""" + + def test_finds_matching_field(self): + conditions = [ + FilterCondition(field="status", op=FilterOp.eq, value="active"), + FilterCondition(field="priority", op=FilterOp.in_, value=["high"]), + ] + assert extract_condition_value(conditions, "priority") == ["high"] + + def test_returns_first_match(self): + conditions = [ + FilterCondition(field="name", value="first"), + FilterCondition(field="name", value="second"), + ] + assert extract_condition_value(conditions, "name") == "first" + + def test_returns_none_when_not_found(self): + conditions = [FilterCondition(field="name", value="test")] + assert extract_condition_value(conditions, "missing") is None + + def test_empty_list_returns_none(self): + assert extract_condition_value([], "anything") is None + + +# --------------------------------------------------------------------------- +# parse_sort_fields (security & validation) +# --------------------------------------------------------------------------- + +class TestParseSortFields: + """Tests for parse_sort_fields including security hardening.""" + + def test_none_returns_empty(self): + assert 
parse_sort_fields(None) == [] + + def test_empty_string_returns_empty(self): + assert parse_sort_fields("") == [] + + def test_valid_single_field(self): + raw = '[{"field": "due_date", "dir": "desc"}]' + result = parse_sort_fields(raw) + assert len(result) == 1 + assert result[0].field == "due_date" + assert result[0].dir == SortDir.desc + + def test_valid_multiple_fields(self): + raw = '[{"field": "date_group", "dir": "asc"}, {"field": "due_date", "dir": "desc"}]' + result = parse_sort_fields(raw) + assert len(result) == 2 + assert result[0].field == "date_group" + assert result[1].dir == SortDir.desc + + def test_defaults_dir_to_asc(self): + raw = '[{"field": "title"}]' + result = parse_sort_fields(raw) + assert result[0].dir == SortDir.asc + + def test_rejects_oversized_payload(self): + with pytest.raises(ValueError, match="size limit"): + parse_sort_fields("x" * 10_001) + + def test_rejects_invalid_json(self): + with pytest.raises(ValueError, match="not valid JSON"): + parse_sort_fields("{not json}") + + def test_rejects_non_array(self): + with pytest.raises(ValueError, match="must be a JSON array"): + parse_sort_fields('{"field": "name"}') + + def test_rejects_too_many_fields(self): + items = [{"field": f"f{i}"} for i in range(11)] + import json + with pytest.raises(ValueError, match="too many sort fields"): + parse_sort_fields(json.dumps(items)) + + def test_rejects_invalid_structure(self): + with pytest.raises(ValueError, match="invalid sort field structure"): + parse_sort_fields('[{"bad_key": "value"}]') + + def test_rejects_invalid_dir(self): + with pytest.raises(ValueError, match="invalid sort field structure"): + parse_sort_fields('[{"field": "name", "dir": "RANDOM"}]') + + def test_at_exact_limit_succeeds(self): + items = [{"field": f"f{i}"} for i in range(10)] + import json + result = parse_sort_fields(json.dumps(items)) + assert len(result) == 10 + + def test_custom_max_fields(self): + items = [{"field": f"f{i}"} for i in range(3)] + import json + 
with pytest.raises(ValueError, match="too many sort fields"): + parse_sort_fields(json.dumps(items), max_fields=2) diff --git a/backend/app/main.py b/backend/app/main.py index 39063fd2..2bd60769 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -96,6 +96,64 @@ async def serve_spa(full_path: str) -> FileResponse: raise HTTPException(status_code=404, detail="SPA bundle not found") +def _inject_query_schemas(openapi_schema: dict) -> None: + """Inject shared query filter/sort schemas into OpenAPI components. + + These schemas (FilterCondition, FilterOp, FilterGroup, SortField, SortDir) + are defined in ``app.schemas.query`` and used by list endpoints that accept + a ``conditions`` JSON query parameter. Injecting them here lets Orval + auto-generate TypeScript types so the frontend never hand-defines them. + """ + from app.schemas.query import ( + FilterCondition, + FilterGroup, + FilterOp, + SortDir, + SortField, + ) + + schemas = openapi_schema.setdefault("components", {}).setdefault("schemas", {}) + + for model in (FilterCondition, FilterGroup, SortField): + full = model.model_json_schema( + ref_template="#/components/schemas/{model}", + ) + defs = full.pop("$defs", {}) + # For self-referencing models (e.g. FilterGroup) the top level is + # just {"$ref": "..."} and the real schema lives in $defs. + if "$ref" in full and not full.get("properties"): + real = defs.pop(model.__name__, full) + schemas[model.__name__] = real + else: + schemas[model.__name__] = full + for name, sub_schema in defs.items(): + schemas.setdefault(name, sub_schema) + + # Enums as standalone schemas (may already be added via $defs above) + for enum_cls in (FilterOp, SortDir): + schemas.setdefault( + enum_cls.__name__, + {"title": enum_cls.__name__, "type": "string", "enum": [e.value for e in enum_cls]}, + ) + + # Override query parameters to expose their real types instead of the raw + # ``string`` that FastAPI infers from the endpoint signature. 
The Axios + # paramsSerializer on the frontend JSON-encodes arrays of objects automatically. + fc_ref = {"$ref": "#/components/schemas/FilterCondition"} + sf_ref = {"$ref": "#/components/schemas/SortField"} + for path_item in openapi_schema.get("paths", {}).values(): + for operation in path_item.values(): + if not isinstance(operation, dict): + continue + for param in operation.get("parameters", []): + if param.get("name") == "conditions" and param.get("in") == "query": + param["schema"] = {"type": "array", "items": fc_ref} + param.pop("anyOf", None) + if param.get("name") == "sorting" and param.get("in") == "query": + param["schema"] = {"type": "array", "items": sf_ref} + param.pop("anyOf", None) + + def custom_openapi() -> dict: if app.openapi_schema: return app.openapi_schema @@ -118,6 +176,8 @@ def custom_openapi() -> dict: }, ) + _inject_query_schemas(openapi_schema) + for path_item in openapi_schema.get("paths", {}).values(): for operation in path_item.values(): if not isinstance(operation, dict): diff --git a/backend/app/schemas/query.py b/backend/app/schemas/query.py new file mode 100644 index 00000000..e15fb2f1 --- /dev/null +++ b/backend/app/schemas/query.py @@ -0,0 +1,108 @@ +"""Shared query schemas for filtering, sorting, and pagination.""" + +from __future__ import annotations + +from enum import Enum +from typing import Any, Generic, Literal, TypeVar + +from pydantic import BaseModel, Field + + +MAX_PAGE_SIZE = 100 + + +class FilterOp(str, Enum): + """Comparison operators for filter conditions. + + Negation is handled by the ``negate`` flag on FilterCondition, + not by separate operators. + """ + eq = "eq" + lt = "lt" + lte = "lte" + gt = "gt" + gte = "gte" + in_ = "in_" + ilike = "ilike" + is_null = "is_null" + + +class SortDir(str, Enum): + asc = "asc" + desc = "desc" + + +class FilterCondition(BaseModel): + """A single field comparison. 
+ + Set ``negate=True`` to invert the result:: + + # name != 'bob' + FilterCondition(field="name", value="bob", negate=True) + + # priority NOT IN ('low', 'medium') + FilterCondition(field="priority", op=FilterOp.in_, value=["low", "medium"], negate=True) + """ + field: str + op: FilterOp = FilterOp.eq + value: Any = None + negate: bool = False + + +class FilterGroup(BaseModel): + """Group of conditions combined with AND or OR logic. + + Set ``negate=True`` to invert the entire group:: + + # NOT (status = 'archived' OR status = 'deleted') + FilterGroup( + logic="or", + negate=True, + conditions=[ + FilterCondition(field="status", value="archived"), + FilterCondition(field="status", value="deleted"), + ], + ) + + Groups can be nested:: + + # is_active = true AND (role = 'admin' OR role = 'owner') + FilterGroup( + logic="and", + conditions=[ + FilterCondition(field="is_active", value=True), + FilterGroup( + logic="or", + conditions=[ + FilterCondition(field="role", value="admin"), + FilterCondition(field="role", value="owner"), + ], + ), + ], + ) + """ + logic: Literal["and", "or"] = "and" + negate: bool = False + conditions: list[FilterCondition | FilterGroup] + + +class SortField(BaseModel): + field: str + dir: SortDir = SortDir.asc + + +class PaginationParams(BaseModel): + page: int = Field(default=1, ge=1) + page_size: int = Field(default=20, ge=0, le=MAX_PAGE_SIZE) + + +T = TypeVar("T") + + +class PaginatedResponse(BaseModel, Generic[T]): + items: list[T] + total_count: int + page: int + page_size: int + has_next: bool + has_prev: bool diff --git a/backend/app/schemas/task.py b/backend/app/schemas/task.py index 932b98e2..f747ee1c 100644 --- a/backend/app/schemas/task.py +++ b/backend/app/schemas/task.py @@ -207,8 +207,8 @@ class TaskListResponse(BaseModel): page: int page_size: int has_next: bool - sort_by: Optional[str] = None - sort_dir: Optional[str] = None + has_prev: bool + sorting: Optional[str] = None class TaskReorderItem(BaseModel): diff --git 
a/frontend/public/locales/en/errors.json b/frontend/public/locales/en/errors.json index b76645bb..eab4e77e 100644 --- a/frontend/public/locales/en/errors.json +++ b/frontend/public/locales/en/errors.json @@ -180,6 +180,8 @@ "IMPORT_INSUFFICIENT_PERMISSION": "Insufficient permissions for this project", "IMPORT_INVALID_STATUS_ID": "Invalid status ID for mapping", "NOTIFICATION_NOT_FOUND": "Notification not found", + "QUERY_INVALID_CONDITIONS": "Invalid filter conditions", + "QUERY_INVALID_SORT_FIELDS": "Invalid sort fields", "RATE_LIMITED": "Too many requests. Please wait a moment and try again.", "fallback": "Something went wrong. Please try again." } diff --git a/frontend/public/locales/es/errors.json b/frontend/public/locales/es/errors.json index 27c739c7..8960febc 100644 --- a/frontend/public/locales/es/errors.json +++ b/frontend/public/locales/es/errors.json @@ -180,6 +180,8 @@ "IMPORT_INSUFFICIENT_PERMISSION": "Permisos insuficientes para este proyecto", "IMPORT_INVALID_STATUS_ID": "ID de estado no válido para el mapeo", "NOTIFICATION_NOT_FOUND": "Notificación no encontrada", + "QUERY_INVALID_CONDITIONS": "Condiciones de filtro no válidas", + "QUERY_INVALID_SORT_FIELDS": "Campos de ordenamiento no válidos", "RATE_LIMITED": "Demasiadas solicitudes. Por favor espera un momento e intenta de nuevo.", "fallback": "Algo salió mal. Por favor intenta de nuevo." 
} diff --git a/frontend/src/__tests__/factories/task.factory.ts b/frontend/src/__tests__/factories/task.factory.ts index e0ae38eb..dae64272 100644 --- a/frontend/src/__tests__/factories/task.factory.ts +++ b/frontend/src/__tests__/factories/task.factory.ts @@ -77,8 +77,8 @@ export function buildTaskListResponse( page: 1, page_size: 50, has_next: false, - sort_by: null, - sort_dir: null, + has_prev: false, + sorting: null, }; } return { @@ -87,8 +87,8 @@ export function buildTaskListResponse( page: 1, page_size: 50, has_next: false, - sort_by: null, - sort_dir: null, + has_prev: false, + sorting: null, ...itemsOrOverrides, }; } diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 1c9a64a9..281fdd0e 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -93,10 +93,16 @@ export const apiClient = axios.create({ paramsSerializer: (params) => { const searchParams = new URLSearchParams(); for (const [key, value] of Object.entries(params)) { + if (value === null || value === undefined) continue; if (Array.isArray(value)) { - // For arrays, add each value with the same key (repeat format) - value.forEach((v) => searchParams.append(key, String(v))); - } else if (value !== null && value !== undefined) { + if (value.length > 0 && typeof value[0] === "object") { + // Arrays of objects (e.g. 
FilterCondition[]) → JSON string + searchParams.append(key, JSON.stringify(value)); + } else { + // Primitive arrays → repeated key format (key=1&key=2) + value.forEach((v) => searchParams.append(key, String(v))); + } + } else { searchParams.append(key, String(value)); } } diff --git a/frontend/src/api/generated/initiativeAPI.schemas.ts b/frontend/src/api/generated/initiativeAPI.schemas.ts index b5b31619..9e3d94df 100644 --- a/frontend/src/api/generated/initiativeAPI.schemas.ts +++ b/frontend/src/api/generated/initiativeAPI.schemas.ts @@ -1600,8 +1600,8 @@ export interface TaskListResponse { page: number; page_size: number; has_next: boolean; - sort_by: string | null; - sort_dir: string | null; + has_prev: boolean; + sorting: string | null; } export interface TaskMoveRequest { @@ -2043,6 +2043,100 @@ export interface VikunjaParseResult { total_tasks?: number; } +/** + * Comparison operators for filter conditions. + +Negation is handled by the ``negate`` flag on FilterCondition, +not by separate operators. + */ +export type FilterOp = (typeof FilterOp)[keyof typeof FilterOp]; + +export const FilterOp = { + eq: "eq", + lt: "lt", + lte: "lte", + gt: "gt", + gte: "gte", + in_: "in_", + ilike: "ilike", + is_null: "is_null", +} as const; + +/** + * A single field comparison. + +Set ``negate=True`` to invert the result:: + + # name != 'bob' + FilterCondition(field="name", value="bob", negate=True) + + # priority NOT IN ('low', 'medium') + FilterCondition(field="priority", op=FilterOp.in_, value=["low", "medium"], negate=True) + */ +export interface FilterCondition { + field: string; + op?: FilterOp; + value?: unknown; + negate?: boolean; +} + +export type FilterGroupLogic = (typeof FilterGroupLogic)[keyof typeof FilterGroupLogic]; + +export const FilterGroupLogic = { + and: "and", + or: "or", +} as const; + +/** + * Group of conditions combined with AND or OR logic. 
+ +Set ``negate=True`` to invert the entire group:: + + # NOT (status = 'archived' OR status = 'deleted') + FilterGroup( + logic="or", + negate=True, + conditions=[ + FilterCondition(field="status", value="archived"), + FilterCondition(field="status", value="deleted"), + ], + ) + +Groups can be nested:: + + # is_active = true AND (role = 'admin' OR role = 'owner') + FilterGroup( + logic="and", + conditions=[ + FilterCondition(field="is_active", value=True), + FilterGroup( + logic="or", + conditions=[ + FilterCondition(field="role", value="admin"), + FilterCondition(field="role", value="owner"), + ], + ), + ], + ) + */ +export interface FilterGroup { + logic?: FilterGroupLogic; + negate?: boolean; + conditions: (FilterCondition | FilterGroup)[]; +} + +export type SortDir = (typeof SortDir)[keyof typeof SortDir]; + +export const SortDir = { + asc: "asc", + desc: "desc", +} as const; + +export interface SortField { + field: string; + dir?: SortDir; +} + export type GetVersionApiV1VersionGet200 = { [key: string]: string }; export type GetLatestDockerhubVersionApiV1VersionLatestGet200 = { [key: string]: string | null }; @@ -2130,18 +2224,11 @@ export type ProjectActivityFeedApiV1ProjectsProjectIdActivityGetParams = { }; export type ListTasksApiV1TasksGetParams = { - project_id?: number | null; scope?: "global" | "global_created" | null; - assignee_ids?: string[] | null; - task_status_ids?: number[] | null; - priorities?: TaskPriority[] | null; - status_category?: TaskStatusCategory[] | null; - initiative_ids?: number[] | null; - guild_ids?: number[] | null; /** - * Filter by tag IDs + * JSON list of filter conditions. Each object: {"field": "", "op": "", "value": }. Any Task column is valid plus virtual fields: status_category, assignee_ids, tag_ids, initiative_ids. 
*/ - tag_ids?: number[] | null; + conditions?: FilterCondition[]; /** * Include archived tasks */ @@ -2156,13 +2243,9 @@ export type ListTasksApiV1TasksGetParams = { */ page_size?: number; /** - * Sort field(s), comma-separated: sort_order, title, due_date, start_date, priority, created_at, updated_at, date_group + * JSON list of sort fields: [{"field": "due_date", "dir": "desc"}] */ - sort_by?: string | null; - /** - * Sort direction(s), comma-separated: asc or desc (one per sort field) - */ - sort_dir?: string | null; + sorting?: SortField[]; }; export type ArchiveDoneTasksApiV1TasksArchiveDonePostParams = { diff --git a/frontend/src/components/projects/ProjectTasksSection.tsx b/frontend/src/components/projects/ProjectTasksSection.tsx index 2a786a9b..b8cec6c2 100644 --- a/frontend/src/components/projects/ProjectTasksSection.tsx +++ b/frontend/src/components/projects/ProjectTasksSection.tsx @@ -24,6 +24,7 @@ import type { LucideIcon } from "lucide-react"; import { toast } from "sonner"; import type { + FilterCondition, ListTasksApiV1TasksGetParams, TaskListRead, TaskListReadRecurrenceStrategy, @@ -197,12 +198,19 @@ export const ProjectTasksSection = ({ const lastKanbanOverRef = useRef(null); // Fetch tasks with server-side filtering (page_size=0 fetches all for drag-and-drop) + const conditions: FilterCondition[] = [ + { field: "project_id", op: "eq", value: projectId }, + ...(assigneeFilters.length > 0 + ? [{ field: "assignee_ids", op: "in_" as const, value: assigneeFilters }] + : []), + ...(statusFilters.length > 0 + ? [{ field: "task_status_id", op: "in_" as const, value: statusFilters }] + : []), + ...(tagFilters.length > 0 ? 
[{ field: "tag_ids", op: "in_" as const, value: tagFilters }] : []), + ]; const taskListParams: ListTasksApiV1TasksGetParams = { - project_id: projectId, + conditions, page_size: 0, - ...(assigneeFilters.length > 0 && { assignee_ids: assigneeFilters }), - ...(statusFilters.length > 0 && { task_status_ids: statusFilters }), - ...(tagFilters.length > 0 && { tag_ids: tagFilters }), ...(showArchived && { include_archived: true }), }; diff --git a/frontend/src/components/tasks/TagTasksTable.tsx b/frontend/src/components/tasks/TagTasksTable.tsx index ec776984..03c1c19d 100644 --- a/frontend/src/components/tasks/TagTasksTable.tsx +++ b/frontend/src/components/tasks/TagTasksTable.tsx @@ -18,10 +18,13 @@ import { DataTable } from "@/components/ui/data-table"; import { useGuilds } from "@/hooks/useGuilds"; import { TaskDescriptionHoverCard } from "@/components/projects/TaskDescriptionHoverCard"; import type { + SortField, TaskListRead, TaskPriority, TaskStatusCategory, TaskStatusRead, + FilterCondition, + ListTasksApiV1TasksGetParams, } from "@/api/generated/initiativeAPI.schemas"; import { SortIcon } from "@/components/SortIcon"; import { dateSortingFn, prioritySortingFn } from "@/lib/sorting"; @@ -80,8 +83,7 @@ export const TagTasksTable = ({ tagId }: TagTasksTableProps) => { const [page, setPageState] = useState(() => searchParams.page ?? 1); const [pageSize, setPageSize] = useState(TAG_TASKS_PAGE_SIZE); - const [sortBy, setSortBy] = useState("due_date"); - const [sortDir, setSortDir] = useState("asc"); + const [sorting, setSorting] = useState([{ field: "due_date", dir: "asc" }]); const statusOptions = useMemo( () => [ @@ -112,20 +114,18 @@ export const TagTasksTable = ({ tagId }: TagTasksTableProps) => { ); const handleSortingChange = useCallback( - (sorting: SortingState) => { - if (sorting.length > 0) { - const col = sorting[0]; - const field = SORT_FIELD_MAP[col.id]; - if (field) { - setSortBy(field); - setSortDir(col.desc ? 
"desc" : "asc"); - } else { - setSortBy(undefined); - setSortDir(undefined); - } + (tableSorting: SortingState) => { + if (tableSorting.length > 0) { + const fields: SortField[] = tableSorting + .map((col) => { + const field = SORT_FIELD_MAP[col.id]; + if (!field) return null; + return { field, dir: col.desc ? "desc" : "asc" } as SortField; + }) + .filter((f): f is SortField => f !== null); + setSorting(fields); } else { - setSortBy(undefined); - setSortDir(undefined); + setSorting([]); } setPage(1); }, @@ -137,33 +137,45 @@ export const TagTasksTable = ({ tagId }: TagTasksTableProps) => { setPage(1); }, [statusFilters, priorityFilters, setPage]); - const taskParams = { - tag_ids: [tagId], - ...(statusFilters.length > 0 && { status_category: statusFilters }), - ...(priorityFilters.length > 0 && { priorities: priorityFilters }), + const taskConditions: FilterCondition[] = [ + { field: "tag_ids", op: "in_", value: [tagId] }, + ...(statusFilters.length > 0 + ? [{ field: "status_category", op: "in_" as const, value: statusFilters }] + : []), + ...(priorityFilters.length > 0 + ? [{ field: "priority", op: "in_" as const, value: priorityFilters }] + : []), + ]; + const taskParams: ListTasksApiV1TasksGetParams = { + conditions: taskConditions, page, page_size: pageSize, - ...(sortBy && { sort_by: sortBy }), - ...(sortDir && { sort_dir: sortDir }), - } as Parameters[0]; + sorting: sorting.length > 0 ? sorting : undefined, + }; const tasksQuery = useTasks(taskParams, { placeholderData: keepPreviousData }); const prefetchPage = useCallback( (targetPage: number) => { if (targetPage < 1) return; - const params = { - tag_ids: [tagId], - ...(statusFilters.length > 0 && { status_category: statusFilters }), - ...(priorityFilters.length > 0 && { priorities: priorityFilters }), + const conditions: FilterCondition[] = [ + { field: "tag_ids", op: "in_", value: [tagId] }, + ...(statusFilters.length > 0 + ? 
[{ field: "status_category", op: "in_" as const, value: statusFilters }] + : []), + ...(priorityFilters.length > 0 + ? [{ field: "priority", op: "in_" as const, value: priorityFilters }] + : []), + ]; + const params: ListTasksApiV1TasksGetParams = { + conditions, page: targetPage, page_size: pageSize, - ...(sortBy && { sort_by: sortBy }), - ...(sortDir && { sort_dir: sortDir }), - } as Parameters[0]; + sorting: sorting.length > 0 ? sorting : undefined, + }; void prefetchTasks(params); }, - [tagId, statusFilters, priorityFilters, pageSize, sortBy, sortDir, prefetchTasks] + [tagId, statusFilters, priorityFilters, pageSize, sorting, prefetchTasks] ); const { mutateAsync: updateTaskStatusMutate, isPending: isUpdatingTaskStatus } = useUpdateTask({ diff --git a/frontend/src/components/ui/data-table.tsx b/frontend/src/components/ui/data-table.tsx index 698564da..5dbfedef 100644 --- a/frontend/src/components/ui/data-table.tsx +++ b/frontend/src/components/ui/data-table.tsx @@ -282,7 +282,7 @@ export function DataTable({ getRowId: getRowId, getPaginationRowModel: enablePagination && !manualPagination ? getPaginationRowModel() : undefined, - getSortedRowModel: manualSorting ? undefined : getSortedRowModel(), + getSortedRowModel: manualSorting && !groupingEnabled ? undefined : getSortedRowModel(), manualSorting: manualSorting, getFilteredRowModel: getFilteredRowModel(), getGroupedRowModel: groupingEnabled ? 
getGroupedRowModel() : undefined, diff --git a/frontend/src/hooks/useGlobalTasksTable.ts b/frontend/src/hooks/useGlobalTasksTable.ts index e3600687..15d3158c 100644 --- a/frontend/src/hooks/useGlobalTasksTable.ts +++ b/frontend/src/hooks/useGlobalTasksTable.ts @@ -15,11 +15,13 @@ import { useUpdateTask } from "@/hooks/useTasks"; import type { ListTasksApiV1TasksGetParams, ProjectRead, + SortField, TaskListRead, TaskListResponse, TaskPriority, TaskStatusCategory, TaskStatusRead, + FilterCondition, } from "@/api/generated/initiativeAPI.schemas"; import { getItem, setItem } from "@/lib/storage"; import { useGuilds } from "@/hooks/useGuilds"; @@ -115,8 +117,10 @@ export function useGlobalTasksTable({ scope, storageKeyPrefix }: UseGlobalTasksT // --- Pagination state --- const [page, setPageState] = useState(() => searchParams.page ?? 1); const [pageSize, setPageSize] = useState(PAGE_SIZE); - const [sortBy, setSortBy] = useState("date_group,due_date"); - const [sortDir, setSortDir] = useState("asc,asc"); + const [sorting, setSorting] = useState([ + { field: "date_group", dir: "asc" }, + { field: "due_date", dir: "asc" }, + ]); const setPage = useCallback( (updater: number | ((prev: number) => number)) => { @@ -137,22 +141,22 @@ export function useGlobalTasksTable({ scope, storageKeyPrefix }: UseGlobalTasksT ); const handleSortingChange = useCallback( - (sorting: SortingState) => { - if (sorting.length > 0) { - const fields = sorting.map((s) => SORT_FIELD_MAP[s.id]).filter(Boolean); - const dirs = sorting - .filter((s) => SORT_FIELD_MAP[s.id]) - .map((s) => (s.desc ? "desc" : "asc")); - if (fields.length > 0) { - setSortBy(fields.join(",")); - setSortDir(dirs.join(",")); - } else { - setSortBy(undefined); - setSortDir(undefined); + (tableSorting: SortingState) => { + if (tableSorting.length > 0) { + const fields: SortField[] = tableSorting + .map((col) => { + const field = SORT_FIELD_MAP[col.id]; + if (!field) return null; + return { field, dir: col.desc ? 
"desc" : "asc" } as SortField; + }) + .filter((f): f is SortField => f !== null); + // date_group needs due_date as secondary sort for meaningful ordering + if (fields.length === 1 && fields[0].field === "date_group") { + fields.push({ field: "due_date", dir: fields[0].dir ?? "asc" }); } + setSorting(fields); } else { - setSortBy(undefined); - setSortDir(undefined); + setSorting([]); } setPage(1); }, @@ -165,19 +169,26 @@ export function useGlobalTasksTable({ scope, storageKeyPrefix }: UseGlobalTasksT }, [statusFilters, priorityFilters, guildFilters, setPage]); // --- Tasks query --- - const tasksParams = useMemo(() => { - const params: ListTasksApiV1TasksGetParams = { + const tasksParams = useMemo((): ListTasksApiV1TasksGetParams => { + const conditions: FilterCondition[] = [ + ...(statusFilters.length > 0 + ? [{ field: "status_category", op: "in_" as const, value: statusFilters }] + : []), + ...(priorityFilters.length > 0 + ? [{ field: "priority", op: "in_" as const, value: priorityFilters }] + : []), + ...(guildFilters.length > 0 + ? [{ field: "guild_id", op: "in_" as const, value: guildFilters }] + : []), + ]; + return { scope: scope as ListTasksApiV1TasksGetParams["scope"], + conditions: conditions.length > 0 ? conditions : undefined, page, page_size: pageSize, - sort_by: sortBy ?? undefined, - sort_dir: sortDir ?? undefined, + sorting: sorting.length > 0 ? 
sorting : undefined, }; - if (statusFilters.length > 0) params.status_category = statusFilters; - if (priorityFilters.length > 0) params.priorities = priorityFilters; - if (guildFilters.length > 0) params.guild_ids = guildFilters; - return params; - }, [scope, statusFilters, priorityFilters, guildFilters, page, pageSize, sortBy, sortDir]); + }, [scope, statusFilters, priorityFilters, guildFilters, page, pageSize, sorting]); const tasksQuery = useQuery({ queryKey: getListTasksApiV1TasksGetQueryKey(tasksParams), @@ -188,10 +199,7 @@ export function useGlobalTasksTable({ scope, storageKeyPrefix }: UseGlobalTasksT const prefetchPage = useCallback( (targetPage: number) => { if (targetPage < 1) return; - const prefetchParams: ListTasksApiV1TasksGetParams = { - ...tasksParams, - page: targetPage, - }; + const prefetchParams: ListTasksApiV1TasksGetParams = { ...tasksParams, page: targetPage }; void localQueryClient.prefetchQuery({ queryKey: getListTasksApiV1TasksGetQueryKey(prefetchParams), diff --git a/frontend/src/pages/GuildDashboardPage.tsx b/frontend/src/pages/GuildDashboardPage.tsx index 619b26c9..db2a7423 100644 --- a/frontend/src/pages/GuildDashboardPage.tsx +++ b/frontend/src/pages/GuildDashboardPage.tsx @@ -29,17 +29,20 @@ import { Button } from "@/components/ui/button"; import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; import { Skeleton } from "@/components/ui/skeleton"; import { FavoriteProjectButton } from "@/components/projects/FavoriteProjectButton"; -import type { ListTasksApiV1TasksGetParams } from "@/api/generated/initiativeAPI.schemas"; -import { TaskStatusCategory } from "@/api/generated/initiativeAPI.schemas"; +import { + TaskStatusCategory, + ListTasksApiV1TasksGetParams, +} from "@/api/generated/initiativeAPI.schemas"; const DASHBOARD_TASK_PARAMS: ListTasksApiV1TasksGetParams = { - status_category: [ - TaskStatusCategory.backlog, - TaskStatusCategory.todo, - TaskStatusCategory.in_progress, + conditions: 
[ + { + field: "status_category", + op: "in_", + value: [TaskStatusCategory.backlog, TaskStatusCategory.todo, TaskStatusCategory.in_progress], + }, ], - sort_by: "due_date", - sort_dir: "asc", + sorting: [{ field: "due_date", dir: "asc" }], page_size: 10, }; diff --git a/frontend/src/routes/_serverRequired/_authenticated/created-tasks.tsx b/frontend/src/routes/_serverRequired/_authenticated/created-tasks.tsx index 7b418b77..bf0ee6f5 100644 --- a/frontend/src/routes/_serverRequired/_authenticated/created-tasks.tsx +++ b/frontend/src/routes/_serverRequired/_authenticated/created-tasks.tsx @@ -37,16 +37,25 @@ export const Route = createFileRoute("/_serverRequired/_authenticated/created-ta const { queryClient } = context; const { statusFilters, priorityFilters, guildFilters } = readStoredFilters(); - const params: Record<string, unknown> = { + const conditions: Array<{ field: string; op: string; value: unknown }> = []; + if (statusFilters.length > 0) + conditions.push({ field: "status_category", op: "in_", value: statusFilters }); + if (priorityFilters.length > 0) + conditions.push({ field: "priority", op: "in_", value: priorityFilters }); + if (guildFilters.length > 0) + conditions.push({ field: "guild_id", op: "in_", value: guildFilters }); + + const defaultSorting = [ + { field: "date_group", dir: "asc" }, + { field: "due_date", dir: "asc" }, + ]; + const params: Record<string, unknown> = { scope: "global_created", page: 1, page_size: PAGE_SIZE, - sort_by: "date_group,due_date", - sort_dir: "asc,asc", + sorting: JSON.stringify(defaultSorting), }; - if (statusFilters.length > 0) params.status_category = statusFilters; - if (priorityFilters.length > 0) params.priorities = priorityFilters; - if (guildFilters.length > 0) params.guild_ids = guildFilters; + if (conditions.length > 0) params.conditions = JSON.stringify(conditions); try { await queryClient.ensureQueryData({ @@ -59,8 +68,7 @@ export const Route = createFileRoute("/_serverRequired/_authenticated/created-ta guildFilters, 1, PAGE_SIZE, - 
"date_group,due_date", - "asc,asc", + "date_group+due_date", ], queryFn: () => apiClient.get("/tasks/", { params }).then((r) => r.data), staleTime: 30_000, diff --git a/frontend/src/routes/_serverRequired/_authenticated/g/$guildId/projects_.$projectId.tsx b/frontend/src/routes/_serverRequired/_authenticated/g/$guildId/projects_.$projectId.tsx index 8beffefb..107cfb28 100644 --- a/frontend/src/routes/_serverRequired/_authenticated/g/$guildId/projects_.$projectId.tsx +++ b/frontend/src/routes/_serverRequired/_authenticated/g/$guildId/projects_.$projectId.tsx @@ -61,12 +61,18 @@ export const Route = createFileRoute( const { assigneeFilters, statusFilters, showArchived } = getStoredFilters(projectId); // Build task query params (page_size=0 fetches all for drag-and-drop) - const taskParams: Record<string, unknown> = { - project_id: projectId, + const conditions: Array<{ field: string; op: string; value: unknown }> = [ + { field: "project_id", op: "eq", value: projectId }, + ]; + if (assigneeFilters.length > 0) + conditions.push({ field: "assignee_ids", op: "in_", value: assigneeFilters }); + if (statusFilters.length > 0) + conditions.push({ field: "task_status_id", op: "in_", value: statusFilters }); + + const taskParams: Record<string, unknown> = { page_size: 0, + conditions: JSON.stringify(conditions), }; - if (assigneeFilters.length > 0) taskParams.assignee_ids = assigneeFilters; - if (statusFilters.length > 0) taskParams.task_status_ids = statusFilters; if (showArchived) taskParams.include_archived = true; // Prefetch in background - don't block navigation on failure diff --git a/frontend/src/routes/_serverRequired/_authenticated/index.tsx b/frontend/src/routes/_serverRequired/_authenticated/index.tsx index 930f8c3c..f386bc8f 100644 --- a/frontend/src/routes/_serverRequired/_authenticated/index.tsx +++ b/frontend/src/routes/_serverRequired/_authenticated/index.tsx @@ -39,16 +39,25 @@ export const Route = createFileRoute("/_serverRequired/_authenticated/")({ const { queryClient } = context; const { 
 statusFilters, priorityFilters, guildFilters } = readStoredFilters(); - const params: Record<string, unknown> = { + const conditions: Array<{ field: string; op: string; value: unknown }> = []; + if (statusFilters.length > 0) + conditions.push({ field: "status_category", op: "in_", value: statusFilters }); + if (priorityFilters.length > 0) + conditions.push({ field: "priority", op: "in_", value: priorityFilters }); + if (guildFilters.length > 0) + conditions.push({ field: "guild_id", op: "in_", value: guildFilters }); + + const defaultSorting = [ + { field: "date_group", dir: "asc" }, + { field: "due_date", dir: "asc" }, + ]; + const params: Record<string, unknown> = { scope: "global", page: 1, page_size: PAGE_SIZE, - sort_by: "date_group,due_date", - sort_dir: "asc,asc", + sorting: JSON.stringify(defaultSorting), }; - if (statusFilters.length > 0) params.status_category = statusFilters; - if (priorityFilters.length > 0) params.priorities = priorityFilters; - if (guildFilters.length > 0) params.guild_ids = guildFilters; + if (conditions.length > 0) params.conditions = JSON.stringify(conditions); try { await queryClient.ensureQueryData({ @@ -61,8 +70,7 @@ export const Route = createFileRoute("/_serverRequired/_authenticated/")({ guildFilters, 1, PAGE_SIZE, - "date_group,due_date", - "asc,asc", + "date_group+due_date", ], queryFn: () => apiClient.get("/tasks/", { params }).then((r) => r.data), staleTime: 30_000,