From 59e53d11697519b611068a4671c9c0c749538597 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Thu, 27 Nov 2025 10:14:05 -0800 Subject: [PATCH 1/4] Set up authorization service, with an AuthContext for verifying role assignments and a separate authorize func to replace validate_access_requests --- .../datajunction_server/api/data.py | 25 +- .../datajunction_server/api/dimensions.py | 82 +- .../datajunction_server/api/namespaces.py | 22 +- .../datajunction_server/api/nodes.py | 47 +- .../datajunction_server/config.py | 12 + .../datajunction_server/database/user.py | 9 + .../internal/access/authentication/basic.py | 31 +- .../internal/access/authorization.py | 592 ++++- .../datajunction_server/models/access.py | 22 +- datajunction-server/tests/conftest.py | 9 + .../tests/internal/authorization_test.py | 1985 +++++++++++++++++ 11 files changed, 2694 insertions(+), 142 deletions(-) create mode 100644 datajunction-server/tests/internal/authorization_test.py diff --git a/datajunction-server/datajunction_server/api/data.py b/datajunction-server/datajunction_server/api/data.py index abe5c0629..5baad69ee 100644 --- a/datajunction-server/datajunction_server/api/data.py +++ b/datajunction-server/datajunction_server/api/data.py @@ -29,8 +29,9 @@ ) from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( + AccessDenialMode, + authorize, validate_access, - validate_access_requests, ) from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.models import access @@ -90,16 +91,16 @@ async def add_availability_state( # Source nodes require that any availability states set are for one of the defined tables node_revision = node.current # type: ignore - validate_access_requests( - validate_access, - current_user, - [ + await authorize( + session=session, + user=current_user, + resource_requests=[ access.ResourceRequest( verb=access.ResourceAction.WRITE, access_object=access.Resource.from_node(node_revision), ), ], - True, + on_denied=AccessDenialMode.RAISE, ) if node.current.type == NodeType.SOURCE: # type: ignore @@ -215,16 +216,16 @@ async def remove_availability_state( ), ) - validate_access_requests( - validate_access, - current_user, - [ + await authorize( + session=session, + user=current_user, + resource_requests=[ access.ResourceRequest( verb=access.ResourceAction.WRITE, - access_object=access.Resource.from_node(node), + access_object=access.Resource.from_node(node.current), # type: ignore ), ], - True, + on_denied=AccessDenialMode.RAISE, ) # Save the old availability state for history record diff --git a/datajunction-server/datajunction_server/api/dimensions.py b/datajunction-server/datajunction_server/api/dimensions.py index 2ebb7b5ca..43e35f46c 100644 --- a/datajunction-server/datajunction_server/api/dimensions.py +++ b/datajunction-server/datajunction_server/api/dimensions.py @@ -14,8 +14,9 @@ from datajunction_server.database.user import User from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( + AccessDenialMode, validate_access, - validate_access_requests, + authorize, ) from datajunction_server.models import access from datajunction_server.models.node import NodeIndegreeOutput @@ -80,29 +81,39 @@ async def find_nodes_with_dimension( List all nodes that have the specified dimension """ dimension_node = await get_node_by_name(session, name) + + # Ensure the user has access to the dimension node first + await authorize( + session=session, + user=current_user, + resource_requests=[ + access.ResourceRequest( + verb=access.ResourceAction.READ, + access_object=access.Resource.from_node(dimension_node), + ) + ], + raise_on_denied=True, + ) + nodes = await get_nodes_with_common_dimensions( session, [dimension_node], node_type if node_type else None, ) - approvals = [ - approval.access_object.name - for approval in validate_access_requests( - validate_access, - current_user, - [ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource( - name=node.name, - resource_type=access.ResourceType.NODE, - owner="", - ), - ) - for node in nodes - ], - ) - ] + + # Only return nodes the user has access to + approvals = await authorize( + session=session, + user=current_user, + resource_requests=[ + access.ResourceRequest( + verb=access.ResourceAction.READ, + access_object=access.Resource.from_node(node), + ) + for node in nodes + ], + on_denied=AccessDenialMode.FILTER, + ) return [NodeNameOutput(name=node.name) for node in nodes if node.name in approvals] @@ -125,22 +136,19 @@ async def find_nodes_with_common_dimensions( [await get_node_by_name(session, dim) for dim in dimension], # type: ignore node_type, ) - approvals = [ - approval.access_object.name - for approval in validate_access_requests( - validate_access, - current_user, - [ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource( - name=node.name, - resource_type=access.ResourceType.NODE, - owner="", - ), - ) - for node in nodes - ], - ) + resource_requests = await authorize( + session=session, + user=current_user, + resource_requests=[ + access.ResourceRequest( + verb=access.ResourceAction.READ, + access_object=access.Resource.from_node(node), + ) + for node in nodes + ], + on_denied=AccessDenialMode.FILTER, + ) + approved_resource_names = [ + request.access_object.name for request in resource_requests ] - return [NodeNameOutput(name=node.name) for node in nodes if node.name in approvals] + return [NodeNameOutput(name=node.name) for node in nodes if node.name in approved_resource_names] diff --git a/datajunction-server/datajunction_server/api/namespaces.py b/datajunction-server/datajunction_server/api/namespaces.py index 42a072bcf..69dfea142 100644 --- a/datajunction-server/datajunction_server/api/namespaces.py +++ b/datajunction-server/datajunction_server/api/namespaces.py @@ -19,8 +19,8 @@ from datajunction_server.models.dimensionlink import LinkType from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( - validate_access, - validate_access_requests, + AccessDenialMode, + authorize, ) from datajunction_server.internal.namespaces import ( create_namespace, @@ -110,9 +110,6 @@ async def create_node_namespace( async def list_namespaces( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), ) -> List[NamespaceOutput]: """ List namespaces with the number of nodes contained in them @@ -125,13 +122,14 @@ async def list_namespaces( ) for record in results ] - approvals = validate_access_requests( - validate_access, - current_user, - resource_requests=resource_requests, - ) - approved_namespaces: List[str] = [ - request.access_object.name for request in approvals + approved_namespaces = [ + request.access_object.name + for request in await authorize( + session=session, + user=current_user, + resource_requests=resource_requests, + on_denied=AccessDenialMode.FILTER, + ) ] return [ NamespaceOutput(namespace=record.namespace, num_nodes=record.num_nodes) diff --git a/datajunction-server/datajunction_server/api/nodes.py b/datajunction-server/datajunction_server/api/nodes.py index 0ead3ca85..60c2add30 100644 --- a/datajunction-server/datajunction_server/api/nodes.py +++ b/datajunction-server/datajunction_server/api/nodes.py @@ -16,6 +16,7 @@ from sqlalchemy.sql.operators import is_ from starlette.requests import Request + from datajunction_server.api.helpers import ( get_catalog_by_name, get_column, @@ -42,8 +43,9 @@ ) from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( + AccessDenialMode, + authorize, validate_access, - validate_access_requests, ) from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.internal.nodes import ( @@ -238,28 +240,24 @@ async def list_nodes( *, session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), ) -> List[str]: """ List the available nodes. """ nodes = await Node.find(session, prefix, node_type) # type: ignore - return [ - approval.access_object.name - for approval in validate_access_requests( - validate_access, - current_user, - [ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource.from_node(node), - ) - for node in nodes - ], - ) - ] + approved_requests = await authorize( + session=session, + user=current_user, + resource_requests=[ + access.ResourceRequest( + verb=access.ResourceAction.READ, + access_object=access.Resource.from_node(node), + ) + for node in nodes + ], + on_denied=AccessDenialMode.FILTER, + ) + return [req.access_object.name for req in approved_requests] @router.get("/nodes/details/", response_model=List[NodeIndexItem]) @@ -269,9 +267,6 @@ async def list_all_nodes_with_details( *, session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), ) -> List[NodeIndexItem]: """ List the available nodes. @@ -302,11 +297,11 @@ async def list_all_nodes_with_details( settings.node_list_max, ) approvals = [ - approval.access_object.name - for approval in validate_access_requests( - validate_access, - current_user, - [ + request.access_object.name + for request in await authorize( + session=session, + user=current_user, + resource_requests=[ access.ResourceRequest( verb=access.ResourceAction.READ, access_object=access.Resource( diff --git a/datajunction-server/datajunction_server/config.py b/datajunction-server/datajunction_server/config.py index ed6941ca6..1f6d307a9 100644 --- a/datajunction-server/datajunction_server/config.py +++ b/datajunction-server/datajunction_server/config.py @@ -157,6 +157,18 @@ class Settings(BaseSettings): # pragma: no cover # or a custom implementation of the GroupMembershipProvider interface group_membership_provider: str = "postgres" + # Authorization configuration + # Provider for authorization checks: + # - "rbac": Role-based access control (default) + # - "passthrough": Always approve (testing/development) + # - Custom implementations can be plugged in + authorization_provider: str = "rbac" + + # Default access policy when no explicit RBAC rule exists: + # - "permissive": Allow by default (OSS-friendly, lock down as needed) + # - "restrictive": Deny by default (Enterprise, explicitly grant access) + default_access_policy: str = "permissive" # or "restrictive" + # Interval in seconds with which to expire caching of any indexes index_cache_expire: int = 60 diff --git a/datajunction-server/datajunction_server/database/user.py b/datajunction-server/datajunction_server/database/user.py index 5b8d56eef..e64f60eed 100644 --- a/datajunction-server/datajunction_server/database/user.py +++ b/datajunction-server/datajunction_server/database/user.py @@ -34,6 +34,7 @@ from datajunction_server.database.notification_preference import ( NotificationPreference, ) + from datajunction_server.database.rbac import RoleAssignment from datajunction_server.database.tag import Tag logger = logging.getLogger(__name__) @@ -144,6 +145,7 @@ class User(Base): ) # Group membership relationships (for kind=GROUP) + # Groups that this user owns (for kind=GROUP) group_members: Mapped[list["GroupMember"]] = relationship( "GroupMember", foreign_keys="GroupMember.group_id", @@ -156,6 +158,13 @@ class User(Base): viewonly=True, ) + # RBAC role assignments (for authorization) + role_assignments: Mapped[list["RoleAssignment"]] = relationship( + "RoleAssignment", + foreign_keys="RoleAssignment.principal_id", + viewonly=True, + ) + @classmethod async def get_by_username( cls, diff --git a/datajunction-server/datajunction_server/internal/access/authentication/basic.py b/datajunction-server/datajunction_server/internal/access/authentication/basic.py index 2026e069d..7bc332151 100644 --- a/datajunction-server/datajunction_server/internal/access/authentication/basic.py +++ b/datajunction-server/datajunction_server/internal/access/authentication/basic.py @@ -5,10 +5,11 @@ import logging from passlib.context import CryptContext -from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql.base import ExecutableOption +from sqlalchemy.orm import selectinload +from datajunction_server.database.rbac import RoleAssignment, Role from datajunction_server.database.user import User from datajunction_server.errors import DJAuthenticationException, DJError, ErrorCode @@ -33,19 +34,29 @@ def get_password_hash(password) -> str: async def get_user( username: str, session: AsyncSession, - *options: ExecutableOption, + options: list[ExecutableOption] | None = None, ) -> User: """ Get a DJ user """ - user = ( - ( - await session.execute( - select(User).where(User.username == username).options(*options), - ) - ) - .unique() - .scalar_one_or_none() + from datajunction_server.database.group_member import GroupMember + + user = await User.get_by_username( + session=session, + username=username, + options=options + or [ + # Load user's direct role assignments + selectinload(User.role_assignments) + .selectinload(RoleAssignment.role) + .selectinload(Role.scopes), + # Load user's group memberships and the groups' role assignments + selectinload(User.member_of) + .selectinload(GroupMember.group) + .selectinload(User.role_assignments) + .selectinload(RoleAssignment.role) + .selectinload(Role.scopes), + ], ) if not user: raise DJAuthenticationException( diff --git a/datajunction-server/datajunction_server/internal/access/authorization.py b/datajunction-server/datajunction_server/internal/access/authorization.py index 8fd26aa71..ca06acbd7 100644 --- a/datajunction-server/datajunction_server/internal/access/authorization.py +++ b/datajunction-server/datajunction_server/internal/access/authorization.py @@ -1,75 +1,579 @@ """ -Authorization related functionality +Authorization related functionality using pluggable services. + +This module provides a pluggable authorization system that works with pre-loaded +user data (no async DB queries needed during authorization): + +- User's roles/assignments are eagerly loaded when fetching the user +- AuthorizationService performs sync in-memory checks +- No changes needed to existing API endpoints +- Keeps the existing validate_access() pattern + +Example custom implementation: +```python +class CustomAuthService(AuthorizationService): + name = "custom" + + def authorize(self, auth_context, requests): + # Sync in-memory authorization logic + return requests +``` """ -from typing import Iterable, List, Union +from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum +from functools import lru_cache +from typing import List, Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload -from datajunction_server.database.node import Node, NodeRevision +from datajunction_server.internal.access.group_membership import ( + GroupMembershipService, + get_group_membership_service, +) +from datajunction_server.database.rbac import RoleAssignment from datajunction_server.database.user import User from datajunction_server.models.access import ( AccessControl, - AccessControlStore, + ResourceAction, ResourceRequest, + ResourceType, ValidateAccessFn, ) -from datajunction_server.models.user import UserOutput +from datajunction_server.utils import SEPARATOR, get_settings +settings = get_settings() -def validate_access_requests( - validate_access: ValidateAccessFn, - user: User, - resource_requests: Iterable[ResourceRequest], - raise_: bool = False, -) -> List[Union[NodeRevision, Node, ResourceRequest]]: - """ - Validate a set of access requests. Only approved requests are returned. - """ - if user is None: - return list(resource_requests) # pragma: no cover - access_control = AccessControlStore( - validate_access=validate_access, - user=UserOutput( - id=user.id, + +# ============================================================================ +# Access Check Modes +# ============================================================================ + + +class AccessDenialMode(Enum): + """ + How to handle denied access requests. + """ + + FILTER = "filter" # Return only approved requests (default for list operations) + RAISE = "raise" # Raise exception if any denied (for single resource operations) + RETURN = ( + "return" # Return all requests with approved field set (for custom handling) + ) + + +# ============================================================================ +# Authorization Context +# ============================================================================ + + +@dataclass(frozen=True) +class AuthContext: + """ + Authorization context for a user. + + Contains all data needed to make authorization decisions, + pre-loaded and ready for fast in-memory checks. + + This separates authorization data from the full User model, + allowing for clean caching, testing, and type safety. + """ + + user_id: int + username: str + oauth_provider: Optional[str] + role_assignments: List[RoleAssignment] # Direct + groups, flattened + + @classmethod + async def from_user( + cls, + session: AsyncSession, + user: User, + group_membership_service: GroupMembershipService | None = None, + ) -> "AuthContext": + """ + Build authorization context from a User object. + + This loads all effective role assignments (direct + group-based) + using the configured GroupMembershipService. + + Args: + session: Database session + user: User to build context for + group_membership_service: Optional service override + + Returns: + AuthContext ready for authorization checks + """ + assignments = await cls.get_effective_assignments( + session=session, + user=user, + group_membership_service=group_membership_service, + ) + + return cls( + user_id=user.id, username=user.username, oauth_provider=user.oauth_provider, - ), - ) + role_assignments=assignments, + ) + + @classmethod + async def get_effective_assignments( + cls, + session: AsyncSession, + user: User, + group_membership_service: GroupMembershipService | None = None, + ) -> List[RoleAssignment]: + """ + Get all effective role assignments for a user (direct + group-based). + + This function: + 1. Takes user's direct role_assignments + 2. Calls GroupMembershipService to get groups (LDAP/local/etc.) + 3. Loads those groups' role_assignments from DJ database + 4. Returns flattened list + + Args: + session: Database session + user: User to get assignments for + group_membership_service: Optional service override + + Returns: + Flat list of all role assignments that apply to this user + """ + from datajunction_server.database.rbac import Role as RoleModel + + # Start with user's direct assignments + assignments = list(user.role_assignments) + + # Get group membership service + if group_membership_service is None: + group_membership_service = get_group_membership_service() + + # Get groups from service (could be LDAP, local DB, etc.) + group_usernames = await group_membership_service.get_user_groups( + session, + user.username, + ) - for request in resource_requests: - access_control.add_request(request) + if not group_usernames: + return assignments # No groups - validation_results = access_control.validate() - if raise_: - access_control.raise_if_invalid_requests() # pragma: no cover - return [result for result in validation_results if result.approved] + # Load groups from DJ database with their role_assignments + stmt = ( + select(User) + .where(User.username.in_(group_usernames)) + .options( + selectinload(User.role_assignments) + .selectinload(RoleAssignment.role) + .selectinload(RoleModel.scopes), + ) + ) + result = await session.execute(stmt) + groups = result.scalars().all() + + # Flatten group assignments into the list + for group in groups: + assignments.extend(group.role_assignments) + + return assignments + + +async def authorize( + session: AsyncSession, + user: User, + resource_requests: List[ResourceRequest], + *, + on_denied: AccessDenialMode = AccessDenialMode.FILTER, +) -> List[ResourceRequest]: + """ + Check access to resources with flexible denial handling. + + Args: + session: Database session + user: User requesting access + resource_requests: Resources to check access for + on_denied: How to handle denied requests: + - FILTER (default): Return only approved requests (for list operations) + - RAISE: Raise DJAuthorizationException if any denied (for single resource) + - RETURN: Return with approved field set (for custom handling) + + Returns: + List of resource requests (filtered or all, depending on on_denied mode) + + Raises: + DJAuthorizationException: If on_denied=RAISE and any requests are denied + """ + auth_context = await AuthContext.from_user(session, user) + auth_service = get_authorization_service() + all_requests = auth_service.authorize(auth_context, resource_requests) + + # Handle based on mode + if on_denied == AccessDenialMode.RETURN: + return all_requests + elif on_denied == AccessDenialMode.RAISE: + denied = [r for r in all_requests if not r.approved] + if denied: + from datajunction_server.errors import ( + DJAuthorizationException, + DJError, + ErrorCode, + ) + + raise DJAuthorizationException( + message=f"Access denied to {len(denied)} resource(s)", + errors=[ + DJError( + code=ErrorCode.UNAUTHORIZED_ACCESS, + message=( + f"{r.verb.value.upper()} access denied to " + f"{r.access_object.resource_type.value}: " + f"{r.access_object.name}" + ), + ) + for r in denied + ], + ) + return all_requests + # Default: FILTER + return [r for r in all_requests if r.approved] def validate_access() -> ValidateAccessFn: """ - A placeholder validate access dependency injected function - that returns a ValidateAccessFn that approves all requests + Default validate access function that uses the configured authorization service. + + This delegates to the pluggable service (RBAC, passthrough, custom, etc.) + using the AuthContext attached to the AccessControl object. """ + auth_service = get_authorization_service() def _(access_control: AccessControl): """ - Examines all requests in the AccessControl - and approves or denies each + Authorizes requests using the configured service. + """ + auth_context = getattr(access_control, "auth_context", None) + if not auth_context: + # No auth context - approve all (backward compat) + access_control.approve_all() + return + + # Use authorization service + requests_list = list(access_control.requests) + auth_service.authorize(auth_context, requests_list) + + return _ + + +# ============================================================================ +# New FastAPI-style Authorization Service +# ============================================================================ + + +class AuthorizationService(ABC): + """ + Abstract base class for authorization strategies. + + Authorization is performed on a pre-loaded authorization context. + + Implementations of this base class decide exactly how to authorize requests: + - RBACAuthorizationService: Uses pre-loaded roles/scopes (default) + - PassthroughAuthorizationService: Always approve (testing/permissive) + - Custom: Your own authorization logic + + Each implementation should define a `name` class attribute to register itself. + """ + + name: str # Subclasses must define this + + @abstractmethod + def authorize( + self, + auth_context: AuthContext, + requests: List[ResourceRequest], + ) -> List[ResourceRequest]: + """ + Authorize resource requests for a user. + + This method should mutate the `approved` field on each request + to indicate whether access is granted. Args: - access_control (AccessControl): The access control object - containing the access control state and requests. + auth_context: Pre-loaded authorization context with all needed data + requests: List of resource requests to authorize + + Returns: + The same list of requests with approved=True/False set on each + """ - Example: - if access_control.state == 'direct': - access_control.approve_all() - return - if access_control.user=='dj': - request.approve_all() - return +class RBACAuthorizationService(AuthorizationService): + """ + Default RBAC implementation using pre-loaded roles and scopes. - request.deny_all() + This implementation: + 1. Works on AuthContext with pre-loaded role_assignments (direct + groups) + 2. Falls back to default_access_policy if no explicit rule exists + 3. Respects role expiration + 4. Synchronous - works on eagerly loaded data + + Group Membership Integration: + - Supports pluggable GroupMembershipService (LDAP, local DB, etc.) + - Groups are loaded when building AuthContext via from_user() + - No DB queries during authorization - all data pre-loaded + """ + + name = "rbac" + + PERMISSION_HIERARCHY = { + ResourceAction.MANAGE: { + ResourceAction.MANAGE, + ResourceAction.DELETE, + ResourceAction.WRITE, + ResourceAction.EXECUTE, + ResourceAction.READ, + }, + ResourceAction.DELETE: { + ResourceAction.DELETE, + ResourceAction.WRITE, + ResourceAction.READ, + }, + ResourceAction.WRITE: { + ResourceAction.WRITE, + ResourceAction.READ, + }, + ResourceAction.EXECUTE: { + ResourceAction.EXECUTE, + ResourceAction.READ, + }, + ResourceAction.READ: { + ResourceAction.READ, + }, + } + + def authorize( + self, + auth_context: AuthContext, + requests: List[ResourceRequest], + ) -> List[ResourceRequest]: """ - access_control.approve_all() + Authorize using pre-loaded RBAC roles and scopes (sync). - return _ + Args: + auth_context: Pre-loaded authorization context with role assignments + requests: Resource requests to authorize + + Returns: + Same list of requests with approved=True/False set + """ + for request in requests: + has_grant = self.has_permission( + assignments=auth_context.role_assignments, + action=request.verb, + resource_type=request.access_object.resource_type, + resource_name=request.access_object.name, + ) + request.approved = ( + has_grant or settings.default_access_policy == "permissive" + ) + return requests + + @classmethod + def resource_matches_pattern(cls, resource_name: str, pattern: str) -> bool: + """ + Check if resource name matches a pattern with wildcard support. + + resource_matches_pattern("finance.revenue", "finance.*") --> True + resource_matches_pattern("finance.quarterly.revenue", "finance.*") --> True + resource_matches_pattern("users.alice.dashboard", "users.alice.*") --> True + resource_matches_pattern("marketing.revenue", "finance.*") --> False + resource_matches_pattern("anything", "*") --> True + resource_matches_pattern("finance", "finance.*") --> False + """ + if pattern == "*": + return True # Match everything + + if "*" not in pattern: + return resource_name == pattern # Exact match + + # Wildcard pattern: finance.* matches finance.revenue and finance.quarterly.revenue + # But NOT just "finance" (must have something after the dot) + pattern_prefix = pattern.rstrip("*").rstrip(SEPARATOR) + + if not pattern_prefix: + return True # Pattern was just "*" + + # Resource must start with pattern_prefix followed by a dot + # (not an exact match to pattern_prefix, that would be handled by exact pattern) + return resource_name.startswith(pattern_prefix + SEPARATOR) + + @classmethod + def has_permission( + cls, + assignments: List, + action: ResourceAction, + resource_type: ResourceType, + resource_name: str, + ) -> bool: + """ + Determine if a list of role assignments grants the requested permission. + + This method iterates through all provided role assignments, checking if any + grant the specified action on the given resource. Expired assignments are + automatically skipped. Returns True if at least one valid assignment grants + access, False otherwise. + + Args: + assignments: List of role assignments to check + action: The action being requested (READ, WRITE, etc.) + resource_type: Type of resource (NODE, NAMESPACE, etc.) + resource_name: Full name/identifier of the resource + + Returns: + True if permission is granted, False otherwise + """ + for assignment in assignments: + # Skip expired assignments + if assignment.expires_at and assignment.expires_at < datetime.now( + timezone.utc, + ): + continue + + # Check each scope in the role + for scope in assignment.role.scopes: + # Check if scope grants permission for this resource + if cls._scope_grants_permission( + scope, + action, + resource_type, + resource_name, + ): + return True + + return False + + @classmethod + def _scope_grants_permission( + cls, + scope, + action: ResourceAction, + resource_type: ResourceType, + resource_name: str, + ) -> bool: + """ + Check if a scope grants permission for a resource. + + Handles: + 1. Permission hierarchy (MANAGE > DELETE > WRITE > READ, EXECUTE > READ) + 2. Empty/None scope_value or "*" = global access + 3. Wildcard pattern matching (finance.*) + 4. Cross-resource-type: namespace scope covers nodes in that namespace + """ + # Check permission hierarchy: does scope.action grant the requested action? + granted_actions = cls.PERMISSION_HIERARCHY.get(scope.action, {scope.action}) + if action not in granted_actions: + return False + + # Handle global access (empty string, None, or "*" scope_value) + if not scope.scope_value or scope.scope_value == "" or scope.scope_value == "*": + # Global scope matches any resource of the same type + return scope.scope_type == resource_type + + # Same resource type - use pattern matching + if scope.scope_type == resource_type: + return cls.resource_matches_pattern(resource_name, scope.scope_value) + + # Cross-resource-type: namespace scope can cover nodes + if ( + scope.scope_type == ResourceType.NAMESPACE + and resource_type == ResourceType.NODE + ): + # Check if node name matches the namespace pattern + return cls.resource_matches_pattern(resource_name, scope.scope_value) + + # No match + return False + + +class PassthroughAuthorizationService(AuthorizationService): + """ + Always approves all requests without checking permissions. + + Useful for: + - Local development + - Testing + - Fully permissive deployments + - Gradual RBAC rollout (start permissive, add rules incrementally) + """ + + name = "passthrough" + + def authorize( + self, + auth_context: AuthContext, + requests: List[ResourceRequest], + ) -> List[ResourceRequest]: + """Approve all requests without checks (sync).""" + for request in requests: + request.approved = True + return requests + + +@lru_cache(maxsize=None) +def get_authorization_service() -> AuthorizationService: + """ + Factory function to get the configured authorization service. + + This is used as a FastAPI dependency. The service can be overridden + via app.dependency_overrides for testing or custom deployments. + + Built-in providers: + - "rbac": Role-based access control using roles/scopes tables (default) + - "passthrough": Always approve all requests + + Configure via environment variable: + ```bash + AUTHORIZATION_PROVIDER=rbac # or passthrough + ``` + + Custom providers can be added by: + 1. Subclassing AuthorizationService + 2. Defining a `name` class attribute + 3. Importing the class before app starts + + Example: + ```python + class ExampleAuthService(AuthorizationService): + name = "example" + + def authorize(self, user, requests): + # Your sync authorization logic + return requests + ``` + + Returns: + AuthorizationService implementation + + Raises: + ValueError: If the configured provider is unknown + """ + provider = getattr(settings, "authorization_provider", "rbac") + + # Discover all subclasses + providers = {} + for subclass in AuthorizationService.__subclasses__(): + if hasattr(subclass, "name"): + providers[subclass.name] = subclass + if subclass.name == provider: + return subclass() # type: ignore[abstract] + + available = ", ".join(sorted(providers.keys())) + raise ValueError( + f"Unknown authorization_provider: '{provider}'. " + f"Available providers: {available}", + ) diff --git a/datajunction-server/datajunction_server/models/access.py b/datajunction-server/datajunction_server/models/access.py index fea00d6bd..aa3727110 100644 --- a/datajunction-server/datajunction_server/models/access.py +++ b/datajunction-server/datajunction_server/models/access.py @@ -16,6 +16,7 @@ from datajunction_server.models.user import UserOutput if TYPE_CHECKING: + from datajunction_server.database.user import User from datajunction_server.sql.parsing.ast import Column @@ -123,11 +124,18 @@ class AccessControl(BaseModel): necessary to deny or approve a request """ + model_config = {"arbitrary_types_allowed": True} + user: str state: AccessControlState direct_requests: Set[ResourceRequest] indirect_requests: Set[ResourceRequest] validation_request_count: int + session: Optional[AsyncSession] = None # For RBAC permission checks (deprecated) + user_id: Optional[int] = None # User ID for RBAC lookups (deprecated) + user_object: Optional["User"] = ( + None # Full User object with role_assignments loaded + ) @property def requests(self) -> Set[ResourceRequest]: @@ -159,6 +167,8 @@ class AccessControlStore(BaseModel): An access control store tracks all ResourceRequests """ + model_config = {"arbitrary_types_allowed": True} + validate_access: Callable[["AccessControl"], bool] user: Optional[UserOutput] base_verb: Optional[ResourceAction] = None @@ -167,6 +177,10 @@ class AccessControlStore(BaseModel): indirect_requests: Set[ResourceRequest] = Field(default_factory=set) validation_request_count: int = 0 validation_results: Set[ResourceRequest] = Field(default_factory=set) + session: Optional[AsyncSession] = None # For RBAC permission checks (deprecated) + user_object: Optional["User"] = ( + None # Full User object with role_assignments loaded + ) def add_request(self, request: ResourceRequest): """ @@ -252,7 +266,9 @@ def raise_if_invalid_requests(self, show_denials: bool = True): def validate(self) -> Set[ResourceRequest]: """ - Checks with ACS and stores any returned invalid requests + Checks with ACS and stores any returned invalid requests. + + Now synchronous - authorization works on pre-loaded user data. """ self.validation_request_count += 1 @@ -262,8 +278,12 @@ def validate(self) -> Set[ResourceRequest]: direct_requests=deepcopy(self.direct_requests), indirect_requests=deepcopy(self.indirect_requests), validation_request_count=self.validation_request_count, + session=self.session, # Deprecated - kept for backward compat + user_id=self.user.id if self.user is not None else None, # Deprecated + user_object=self.user_object, # Pass full User object ) + # Call validate_access (now sync!) self.validate_access(access_control) # type: ignore self.validation_results = access_control.requests diff --git a/datajunction-server/tests/conftest.py b/datajunction-server/tests/conftest.py index 2728dbb20..e5ad2d0ca 100644 --- a/datajunction-server/tests/conftest.py +++ b/datajunction-server/tests/conftest.py @@ -944,6 +944,15 @@ async def create_default_user(session: AsyncSession) -> User: return user +@pytest_asyncio.fixture +async def default_user(session: AsyncSession): + """ + Create a default user for testing. + """ + user = await create_default_user(session) + yield user + + @pytest_asyncio.fixture(scope="module") async def module__client( request, diff --git a/datajunction-server/tests/internal/authorization_test.py b/datajunction-server/tests/internal/authorization_test.py new file mode 100644 index 000000000..e4b4d3fcf --- /dev/null +++ b/datajunction-server/tests/internal/authorization_test.py @@ -0,0 +1,1985 @@ +"""Tests for RBAC authorization logic.""" + +from datetime import datetime, timedelta, timezone + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from datajunction_server.database.group_member import GroupMember +from datajunction_server.database.rbac import Role, RoleAssignment, RoleScope +from datajunction_server.database.user import PrincipalKind, User +from datajunction_server.internal.access.authorization import ( + AccessDenialMode, + AuthContext, + PassthroughAuthorizationService, + RBACAuthorizationService, + authorize, + get_authorization_service, +) +from datajunction_server.errors import DJAuthorizationException +from datajunction_server.internal.access.authentication.basic import get_user +from datajunction_server.models.access import ( + Resource, + ResourceAction, + ResourceRequest, + ResourceType, +) +from datajunction_server.internal.access.group_membership import ( + GroupMembershipService, +) + + +class TestResourceMatching: + """Tests for wildcard pattern matching.""" + + def test_exact_match(self): + """Test exact string matching (no wildcards).""" + assert RBACAuthorizationService.resource_matches_pattern( + "finance.revenue", + "finance.revenue", + ) + assert not RBACAuthorizationService.resource_matches_pattern( + "finance.revenue", + "finance.cost", + ) + + def test_wildcard_all(self): + """Test the universal wildcard *.""" + assert RBACAuthorizationService.resource_matches_pattern("anything", "*") + assert RBACAuthorizationService.resource_matches_pattern( + "finance.revenue.quarterly", + "*", + ) + assert RBACAuthorizationService.resource_matches_pattern("", "*") + + def test_namespace_wildcard(self): + """Test namespace wildcard patterns.""" + # finance.* matches finance.revenue + assert RBACAuthorizationService.resource_matches_pattern( + "finance.revenue", + "finance.*", + ) + + # finance.* matches finance.quarterly.revenue + assert RBACAuthorizationService.resource_matches_pattern( + "finance.quarterly.revenue", + "finance.*", + ) + + # finance.* does NOT match finance (exact namespace) + assert not RBACAuthorizationService.resource_matches_pattern( + "finance", + "finance.*", + ) + + # finance.* does NOT match marketing.revenue + assert not RBACAuthorizationService.resource_matches_pattern( + "marketing.revenue", + "finance.*", + ) + + def test_nested_namespace_wildcard(self): + """Test nested namespace patterns.""" + # users.alice.* matches users.alice.dashboard + assert RBACAuthorizationService.resource_matches_pattern( + "users.alice.dashboard", + "users.alice.*", + ) + + # users.alice.* matches users.alice.metrics.revenue + assert RBACAuthorizationService.resource_matches_pattern( + "users.alice.metrics.revenue", + "users.alice.*", + ) + + # users.alice.* does NOT match users.bob.dashboard + assert not RBACAuthorizationService.resource_matches_pattern( + "users.bob.dashboard", + "users.alice.*", + ) + + +@pytest.mark.asyncio +class TestRBACPermissionChecks: + """Tests for RBAC permission checking.""" + + async def test_no_roles_returns_false( + self, + default_user: User, + session: AsyncSession, + ): + """Test that user with no roles gets False (no explicit rule).""" + user = await get_user(username=default_user.username, session=session) + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + + assert result is False + + async def test_explicit_grant_exact_match( + self, + default_user: User, + session: AsyncSession, + ): + """Test explicit permission grant with exact resource match.""" + # Create role with exact scope + role = Role( + name="test-role", + created_by_id=default_user.id, + ) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.revenue", + ) + session.add(scope) + + # Assign role to user + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Check permission + user = await get_user(username=default_user.username, session=session) + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result is True + + async def test_explicit_grant_wildcard_match( + self, + default_user: User, + session: AsyncSession, + ): + """Test permission grant via wildcard pattern.""" + # Create role with wildcard scope + role = Role( + name="finance-reader", + created_by_id=default_user.id, + ) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", # Wildcard + ) + session.add(scope) + + # Assign role + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Check permissions on various resources + user = await get_user(username=default_user.username, session=session) + result1 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result1 is True + + result2 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.quarterly.revenue", + ) + assert result2 is True + + # Different namespace - no match + result3 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="marketing.revenue", + ) + assert result3 is False + + async def test_wrong_action_no_match( + self, + default_user: User, + session: AsyncSession, + ): + """Test that wrong action doesn't grant permission.""" + role = Role(name="reader-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + # Only READ permission + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Check READ - should be granted + user = await get_user(username=default_user.username, session=session) + result_read = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result_read is True + + # Check WRITE - should be None (no explicit rule) + result_write = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.WRITE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result_write is False + + async def test_expired_assignment_ignored( + self, + default_user: User, + session: AsyncSession, + ): + """Test that expired role assignments are ignored.""" + role = Role(name="temp-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + # Assignment expired 1 hour ago + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + expires_at=datetime.now(timezone.utc) - timedelta(hours=1), + ) + session.add(assignment) + await session.commit() + + # Should not grant permission (expired) + user = await get_user(username=default_user.username, session=session) + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert not result + + async def test_multiple_roles_any_grants( + self, + default_user: User, + session: AsyncSession, + ): + """Test that having ANY role that grants permission is sufficient.""" + # Role 1: No matching scope + role1 = Role(name="marketing-role", created_by_id=default_user.id) + session.add(role1) + await session.flush() + + scope1 = RoleScope( + role_id=role1.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="marketing.*", + ) + session.add(scope1) + + assignment1 = RoleAssignment( + principal_id=default_user.id, + role_id=role1.id, + granted_by_id=default_user.id, + ) + session.add(assignment1) + + # Role 2: Matching scope + role2 = Role(name="finance-role", created_by_id=default_user.id) + session.add(role2) + await session.flush() + + scope2 = RoleScope( + role_id=role2.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope2) + + assignment2 = RoleAssignment( + principal_id=default_user.id, + role_id=role2.id, + granted_by_id=default_user.id, + ) + session.add(assignment2) + await session.commit() + + # Should grant because role2 matches + user = await get_user(username=default_user.username, session=session) + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result is True + + async def test_universal_wildcard( + self, + default_user: User, + session: AsyncSession, + ): + """Test that * wildcard grants access to everything.""" + role = Role(name="super-admin", created_by_id=default_user.id) + session.add(role) + await session.flush() + + # Universal wildcard + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Should grant for anything + user = await get_user(username=default_user.username, session=session) + result1 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result1 is True + + result2 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="anything.at.all", + ) + assert result2 is True + + +@pytest.mark.asyncio +class TestAuthorizationService: + """Tests for the synchronous AuthorizationService.""" + + async def test_passthrough_service_approves_all( + self, + default_user: User, + session: AsyncSession, + ): + """Test that PassthroughAuthorizationService approves everything.""" + # Get existing user + user = await get_user(username=default_user.username, session=session) + + service = PassthroughAuthorizationService() + + requests = [ + ResourceRequest( + verb=ResourceAction.WRITE, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NAMESPACE, + owner="", + ), + ), + ResourceRequest( + verb=ResourceAction.DELETE, + access_object=Resource( + name="secret.data", + resource_type=ResourceType.NODE, + owner="", + ), + ), + ] + + result = service.authorize(user, requests) # Now sync! + + assert len(result) == 2 + assert all(req.approved for req in result) + + async def test_rbac_service_with_permissions( + self, + session: AsyncSession, + default_user: User, + mocker, + ): + """Test RBACAuthorizationService with granted permissions.""" + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + role = Role(name="test-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + requests = [ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NAMESPACE, + owner="", + ), + ), + ResourceRequest( + verb=ResourceAction.WRITE, # Not granted + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NAMESPACE, + owner="", + ), + ), + ] + + result = await authorize( + session, + user, + requests, + on_denied=AccessDenialMode.RETURN, + ) + assert len(result) == 2 + assert result[0].approved is True # READ granted + assert result[1].approved is False # WRITE not granted + + async def test_get_authorization_service_factory(self, mocker): + """Test the factory function returns correct service.""" + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + service = get_authorization_service() + assert isinstance(service, RBACAuthorizationService) + + # Test passthrough provider + mock_settings.authorization_provider = "passthrough" + service = get_authorization_service() + + # Cached instance, so need to clear cache + assert isinstance(service, RBACAuthorizationService) + + # Clear LRU cache to test different provider + get_authorization_service.cache_clear() + service = get_authorization_service() + assert isinstance(service, PassthroughAuthorizationService) + + # Test unknown provider + mock_settings.authorization_provider = "unknown" + get_authorization_service.cache_clear() + with pytest.raises(ValueError) as exc_info: + get_authorization_service() + assert "unknown" in str(exc_info.value).lower() + assert "rbac" in str(exc_info.value).lower() + assert "passthrough" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +class TestGroupBasedPermissions: + """Tests for group-based role assignments.""" + + async def test_user_inherits_group_permissions( + self, + session: AsyncSession, + default_user: User, + mocker, + ): + """Test that users inherit permissions from groups they belong to.""" + # Create a group + group = User( + username="finance-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Add user to group + membership = GroupMember( + group_id=group.id, + member_id=default_user.id, + ) + session.add(membership) + await session.flush() + + # Create role and assign to group + role = Role(name="finance-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + # Assign role to group + assignment = RoleAssignment( + principal_id=group.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Expire the user object so we get a fresh load + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + # Check permission - should be granted via group + results = await authorize( + session=session, + user=user, + resource_requests=[ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.revenue.something", + resource_type=ResourceType.NODE, + owner="", + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + resource_type=ResourceType.NAMESPACE, + name="finance.revenue", + owner="", + ), + ), + ], + ) + assert results[0].approved is True + assert results[1].approved is True + + async def test_user_no_permission_without_group( + self, + session: AsyncSession, + default_user: User, + mocker, + ): + """Test that user without group membership doesn't get permission.""" + # Create a group with permissions + group = User( + username="marketing-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Create role and assign to GROUP + role = Role(name="marketing-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="marketing.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=group.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user without adding them to the group + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + # Check permission - should NOT be granted (user not in group) + results = await authorize( + session=session, + user=user, + resource_requests=[ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="marketing.revenue", + resource_type=ResourceType.NAMESPACE, + owner="", + ), + ), + ], + on_denied=AccessDenialMode.RETURN, + ) + assert results[0].approved is False + + +@pytest.mark.asyncio +class TestCrossResourceTypePermissions: + """Tests for namespace scopes covering nodes.""" + + async def test_namespace_scope_covers_nodes( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that namespace scope grants permission for nodes in that namespace.""" + # Create role with NAMESPACE scope + role = Role(name="finance-ns-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, # Namespace scope + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + + # Reload user with member_of and group role_assignments eagerly loaded + user = await get_user(username=default_user.username, session=session) + + # Check permission for NAMESPACE resource - should be granted + result_namespace = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result_namespace is True + + # Check permission for NODE resource in that namespace - should ALSO be granted! + result_node = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="finance.revenue.total", # Node in finance namespace + ) + assert result_node is True + + # Node in different namespace - should NOT be granted + result_other = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="marketing.revenue.total", + ) + assert result_other is False + + async def test_namespace_scope_nested_namespaces( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test namespace scope with nested namespaces.""" + # Create role with wildcard namespace scope + role = Role(name="finance-all", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", # finance.* + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # finance.quarterly.revenue node should match finance.* namespace + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="finance.quarterly.revenue", + ) + assert result is True + + async def test_node_scope_does_not_cover_namespace( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that NODE scope does NOT grant permission for NAMESPACE resources.""" + # Create role with NODE scope + role = Role(name="specific-node-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NODE, # NODE scope + scope_value="finance.revenue", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # Check permission for NODE - should be granted + result_node = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="finance.revenue", + ) + assert result_node is True + + # Check permission for NAMESPACE - should NOT be granted (cross-type only works one way) + result_namespace = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result_namespace is False + + +@pytest.mark.asyncio +class TestGlobalAccessScope: + """Tests for global access (empty or * scope_value).""" + + async def test_empty_scope_grants_global_access( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that empty scope_value grants access to all resources of that type.""" + # Create role with empty scope_value (global) + role = Role(name="global-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="", # Global! (empty string) + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # Should grant for any namespace + result1 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + assert result1 is True + + result2 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="marketing.anything", + ) + assert result2 is True + + # Should NOT grant for different resource type + result3 = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, # Different type + resource_name="finance.revenue", + ) + assert result3 is False + + async def test_star_scope_grants_global_access( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that "*" scope_value grants access to all resources of that type.""" + # Create role with "*" scope_value + role = Role(name="star-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NODE, + scope_value="*", # Wildcard for all + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # Should grant for any node + result = RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="anything.anywhere.node", + ) + assert result is True + + +@pytest.mark.asyncio +class TestPermissionHierarchy: + """Tests for permission hierarchy (MANAGE > DELETE > WRITE > READ).""" + + async def test_manage_implies_all_permissions( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that MANAGE permission grants all other permissions.""" + # Create role with MANAGE permission + role = Role(name="finance-manager", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.MANAGE, # Top-level permission + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # MANAGE should grant READ + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # MANAGE should grant WRITE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.WRITE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # MANAGE should grant DELETE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.DELETE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # MANAGE should grant EXECUTE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.EXECUTE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # MANAGE should grant MANAGE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.MANAGE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + async def test_write_implies_read( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that WRITE permission implies READ.""" + # Create role with WRITE permission + role = Role(name="finance-writer", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.WRITE, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # WRITE should grant READ + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # WRITE should grant WRITE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.WRITE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # WRITE should NOT grant DELETE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.DELETE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is False + ) + + async def test_read_does_not_imply_write( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that READ permission does NOT imply WRITE.""" + # Create role with only READ permission + role = Role(name="readonly-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # READ should grant READ + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is True + ) + + # READ should NOT grant WRITE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.WRITE, + resource_type=ResourceType.NAMESPACE, + resource_name="finance.revenue", + ) + is False + ) + + async def test_execute_implies_read( + self, + # client_with_basic: AsyncClient, + session: AsyncSession, + default_user: User, + ): + """Test that EXECUTE permission implies READ.""" + # Create role with EXECUTE permission + role = Role(name="query-executor", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.EXECUTE, + scope_type=ResourceType.NODE, + scope_value="finance.revenue", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with roles + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # EXECUTE should grant READ + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.READ, + resource_type=ResourceType.NODE, + resource_name="finance.revenue", + ) + is True + ) + + # EXECUTE should grant EXECUTE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.EXECUTE, + resource_type=ResourceType.NODE, + resource_name="finance.revenue", + ) + is True + ) + + # EXECUTE should NOT grant WRITE + assert ( + RBACAuthorizationService.has_permission( + assignments=user.role_assignments, + action=ResourceAction.WRITE, + resource_type=ResourceType.NODE, + resource_name="finance.revenue", + ) + is False + ) + + +@pytest.mark.asyncio +class TestAuthContext: + """Tests for AuthContext and effective assignments.""" + + async def test_auth_context_from_user_direct_assignments_only( + self, + default_user: User, + session: AsyncSession, + ): + """AuthContext includes user's direct role assignments.""" + # Create role and assign to user + role = Role(name="test-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user with assignments + user = await get_user(username=default_user.username, session=session) + + # Build AuthContext + auth_context = await AuthContext.from_user(session, user) + + assert auth_context.user_id == user.id + assert auth_context.username == user.username + assert len(auth_context.role_assignments) == 1 + assert auth_context.role_assignments[0].role.name == "test-role" + + async def test_auth_context_includes_group_assignments( + self, + default_user: User, + session: AsyncSession, + ): + """AuthContext flattens user's + groups' assignments.""" + # Create a group + group = User( + username="finance-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Add user to group + membership = GroupMember( + group_id=group.id, + member_id=default_user.id, + ) + session.add(membership) + + # Create role for user (direct) + user_role = Role(name="user-role", created_by_id=default_user.id) + session.add(user_role) + await session.flush() + + user_scope = RoleScope( + role_id=user_role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="personal.*", + ) + session.add(user_scope) + + user_assignment = RoleAssignment( + principal_id=default_user.id, + role_id=user_role.id, + granted_by_id=default_user.id, + ) + session.add(user_assignment) + + # Create role for group + group_role = Role(name="group-role", created_by_id=default_user.id) + session.add(group_role) + await session.flush() + + group_scope = RoleScope( + role_id=group_role.id, + action=ResourceAction.WRITE, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(group_scope) + + group_assignment = RoleAssignment( + principal_id=group.id, + role_id=group_role.id, + granted_by_id=default_user.id, + ) + session.add(group_assignment) + await session.commit() + + # Reload user + user = await get_user(username=default_user.username, session=session) + + # Build AuthContext (should include both) + auth_context = await AuthContext.from_user(session, user) + + assert auth_context.user_id == user.id + assert len(auth_context.role_assignments) == 2 # User's + group's + + role_names = {a.role.name for a in auth_context.role_assignments} + assert role_names == {"user-role", "group-role"} + + async def test_auth_context_with_multiple_groups( + self, + default_user: User, + session: AsyncSession, + ): + """User in multiple groups gets all group assignments.""" + # Create two groups + group1 = User( + username="finance-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + group2 = User( + username="data-eng-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add_all([group1, group2]) + await session.flush() + + # Add user to both groups + membership1 = GroupMember(group_id=group1.id, member_id=default_user.id) + membership2 = GroupMember(group_id=group2.id, member_id=default_user.id) + session.add_all([membership1, membership2]) + + # Give each group a role + role1 = Role(name="finance-role", created_by_id=default_user.id) + role2 = Role(name="data-eng-role", created_by_id=default_user.id) + session.add_all([role1, role2]) + await session.flush() + + scope1 = RoleScope( + role_id=role1.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + scope2 = RoleScope( + role_id=role2.id, + action=ResourceAction.WRITE, + scope_type=ResourceType.NAMESPACE, + scope_value="analytics.*", + ) + session.add_all([scope1, scope2]) + + assignment1 = RoleAssignment( + principal_id=group1.id, + role_id=role1.id, + granted_by_id=default_user.id, + ) + assignment2 = RoleAssignment( + principal_id=group2.id, + role_id=role2.id, + granted_by_id=default_user.id, + ) + session.add_all([assignment1, assignment2]) + await session.commit() + + # Reload user + user = await get_user(username=default_user.username, session=session) + + # Build AuthContext + auth_context = await AuthContext.from_user(session, user) + + # Should have assignments from both groups + assert len(auth_context.role_assignments) == 2 + role_names = {a.role.name for a in auth_context.role_assignments} + assert role_names == {"finance-role", "data-eng-role"} + + +@pytest.mark.asyncio +class TestCheckAccess: + """Tests for authorize() function with different denial modes.""" + + async def test_check_access_filter_mode_returns_only_approved( + self, + default_user: User, + session: AsyncSession, + mocker, + ): + """FILTER mode returns only approved requests (default).""" + # Give user access to finance.* but not marketing.* + role = Role(name="finance-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + # Reload user + await session.refresh(default_user) + user = await get_user(username=default_user.username, session=session) + + # Request access to 3 nodes: 2 accessible, 1 not + requests = [ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NODE, + owner="", + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.cost", + resource_type=ResourceType.NODE, + owner="", + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="marketing.revenue", + resource_type=ResourceType.NODE, + owner="", + ), + ), + ] + + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + # Check access (default FILTER mode) + approved = await authorize( + session, + user, + requests, + on_denied=AccessDenialMode.FILTER, + ) + + # Should only return the 2 approved (finance.* nodes) + assert len(approved) == 2 + assert all(req.approved for req in approved) + approved_names = {req.access_object.name for req in approved} + assert approved_names == {"finance.revenue", "finance.cost"} + + async def test_check_access_raise_mode_throws_on_denial( + self, + default_user: User, + session: AsyncSession, + mocker, + ): + """ + Raise mode throws DJAuthorizationException when access denied + for a user with no permissions. + """ + user = await get_user(username=default_user.username, session=session) + + request = ResourceRequest( + verb=ResourceAction.WRITE, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NODE, + owner="", + ), + ) + + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + with pytest.raises(DJAuthorizationException) as exc_info: + await authorize( + session, + user, + [request], + on_denied=AccessDenialMode.RAISE, + ) + + # Check exception message + assert "Access denied" in str(exc_info.value) + assert "WRITE" in str(exc_info.value) + assert "finance.revenue" in str(exc_info.value) + + async def test_check_access_raise_mode_succeeds_when_approved( + self, + default_user: User, + session: AsyncSession, + ): + """RAISE mode succeeds without exception when all approved.""" + # Give user access + role = Role(name="finance-writer", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.WRITE, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + request = ResourceRequest( + verb=ResourceAction.WRITE, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NODE, + owner="", + ), + ) + + # Should NOT raise + result = await authorize( + session, + user, + [request], + on_denied=AccessDenialMode.RAISE, + ) + + assert len(result) == 1 + assert result[0].approved is True + + async def test_check_access_return_mode( + self, + default_user: User, + session: AsyncSession, + mocker, + ): + """RETURN_ALL mode returns all requests with approved field set.""" + # Give user access to finance.* only + role = Role(name="finance-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Request access to 3 nodes: 2 accessible, 1 not + requests = [ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NODE, + owner="", + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.cost", + resource_type=ResourceType.NODE, + owner="", + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="marketing.revenue", + resource_type=ResourceType.NODE, + owner="", + ), + ), + ] + + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + # Check access with RETURN_ALL + all_requests = await authorize( + session, + user, + requests, + on_denied=AccessDenialMode.RETURN, + ) + + # Should return all 3 requests + assert len(all_requests) == 3 + + # 2 approved, 1 denied + approved = [r for r in all_requests if r.approved] + denied = [r for r in all_requests if not r.approved] + + assert len(approved) == 2 + assert len(denied) == 1 + assert denied[0].access_object.name == "marketing.revenue" + + +@pytest.mark.asyncio +class TestGetEffectiveAssignments: + """Tests for get_effective_assignments() with GroupMembershipService.""" + + async def test_effective_assignments_user_only( + self, + default_user: User, + session: AsyncSession, + ): + """User with no groups gets only direct assignments.""" + # Give user a direct assignment + role = Role(name="personal-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="personal.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Get effective assignments + assignments = await AuthContext.get_effective_assignments(session, user) + + assert len(assignments) == 1 + assert assignments[0].role.name == "personal-role" + + async def test_effective_assignments_with_postgres_groups( + self, + default_user: User, + session: AsyncSession, + ): + """Effective assignments includes groups from PostgresGroupMembershipService.""" + # Create group + group = User( + username="test-group", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Add user to group via GroupMember table + membership = GroupMember( + group_id=group.id, + member_id=default_user.id, + ) + session.add(membership) + + # Create role and assign to GROUP + role = Role(name="group-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.WRITE, + scope_type=ResourceType.NAMESPACE, + scope_value="shared.*", + ) + session.add(scope) + + group_assignment = RoleAssignment( + principal_id=group.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(group_assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Get effective assignments (should use PostgresGroupMembershipService by default) + assignments = await AuthContext.get_effective_assignments(session, user) + + # Should include group's assignment + assert len(assignments) >= 1 + role_names = {a.role.name for a in assignments} + assert "group-role" in role_names + + async def test_effective_assignments_with_custom_service( + self, + default_user: User, + session: AsyncSession, + mocker, + ): + """Custom GroupMembershipService can be provided.""" + + # Create a mock service that returns a specific group + class MockGroupService(GroupMembershipService): + name = "mock" + + async def is_user_in_group(self, session, username, group_name): + return group_name == "mock-group" + + async def get_user_groups(self, session, username): + return ["mock-group"] + + async def add_user_to_group(self, session, username, group_name): + pass + + async def remove_user_from_group(self, session, username, group_name): + pass + + # Create the mock group in DB + group = User( + username="mock-group", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Assign role to mock group + role = Role(name="mock-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.EXECUTE, + scope_type=ResourceType.NAMESPACE, + scope_value="special.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=group.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Use custom service + mock_service = MockGroupService() + assignments = await AuthContext.get_effective_assignments( + session, + user, + mock_service, + ) + + # Should include mock group's assignment + role_names = {a.role.name for a in assignments} + assert "mock-role" in role_names + + +@pytest.mark.asyncio +class TestCheckAccessIntegration: + """Integration tests for authorize() with real authorization flow.""" + + async def test_check_access_with_group_based_permissions( + self, + default_user: User, + session: AsyncSession, + ): + """End-to-end: User gets access via group membership.""" + # Create group + group = User( + username="data-team", + kind=PrincipalKind.GROUP, + oauth_provider="basic", + ) + session.add(group) + await session.flush() + + # Add user to group + membership = GroupMember( + group_id=group.id, + member_id=default_user.id, + ) + session.add(membership) + + # Give group permission + role = Role(name="data-team-role", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="data.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=group.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Request access to data.* node + request = ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="data.user_events", + resource_type=ResourceType.NODE, + owner="", + ), + ) + + # Should be approved via group + approved = await authorize(session, user, [request]) + + assert len(approved) == 1 + assert approved[0].approved is True + + async def test_check_access_with_mixed_approval( + self, + default_user: User, + session: AsyncSession, + mocker, + ): + """Some requests approved, some denied.""" + # Give access to finance.* only + role = Role(name="finance-reader", created_by_id=default_user.id) + session.add(role) + await session.flush() + + scope = RoleScope( + role_id=role.id, + action=ResourceAction.READ, + scope_type=ResourceType.NAMESPACE, + scope_value="finance.*", + ) + session.add(scope) + + assignment = RoleAssignment( + principal_id=default_user.id, + role_id=role.id, + granted_by_id=default_user.id, + ) + session.add(assignment) + await session.commit() + + user = await get_user(username=default_user.username, session=session) + + # Mix of accessible and inaccessible + requests = [ + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="finance.revenue", + resource_type=ResourceType.NODE, + owner="", + ), + ), + ResourceRequest( + verb=ResourceAction.READ, + access_object=Resource( + name="marketing.revenue", + resource_type=ResourceType.NODE, + owner="", + ), + ), + ] + mock_settings = mocker.patch( + "datajunction_server.internal.access.authorization.settings", + ) + mock_settings.authorization_provider = "rbac" + mock_settings.default_access_policy = "restrictive" + + # FILTER mode - returns only approved + filtered = await authorize( + session, + user, + requests, + on_denied=AccessDenialMode.FILTER, + ) + assert len(filtered) == 1 + assert filtered[0].access_object.name == "finance.revenue" + + # RETURN_ALL mode - returns both + all_results = await authorize( + session, + user, + requests, + on_denied=AccessDenialMode.RETURN, + ) + assert len(all_results) == 2 + assert all_results[0].approved is True + assert all_results[1].approved is False + + # RAISE mode - should raise + with pytest.raises(DJAuthorizationException): + await authorize( + session, + user, + requests, + on_denied=AccessDenialMode.RAISE, + ) From 885392f2d0435b09741fc21f274bbabf170f9e58 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 1 Dec 2025 18:43:28 -0800 Subject: [PATCH 2/4] Fix all endpoints and tests to use new setup with AccessChecker --- .../datajunction_server/api/cubes.py | 18 +- .../datajunction_server/api/data.py | 74 ++--- .../datajunction_server/api/deployments.py | 20 +- .../datajunction_server/api/dimensions.py | 79 +---- .../datajunction_server/api/djsql.py | 32 +- .../datajunction_server/api/helpers.py | 16 +- .../api/materializations.py | 12 +- .../datajunction_server/api/metrics.py | 16 +- .../datajunction_server/api/namespaces.py | 27 +- .../datajunction_server/api/nodes.py | 75 ++--- .../datajunction_server/api/sql.py | 238 ++++++++++++-- .../datajunction_server/api/system.py | 12 - .../datajunction_server/construction/build.py | 6 +- .../construction/build_v2.py | 78 ++--- .../construction/dimensions.py | 7 +- .../internal/access/authorization.py | 299 +++++++++++------- .../internal/caching/query_cache_manager.py | 49 +-- .../internal/deployment/utils.py | 2 - .../internal/materializations.py | 12 +- .../datajunction_server/internal/nodes.py | 55 ++-- .../datajunction_server/internal/sql.py | 48 ++- .../datajunction_server/models/access.py | 255 ++------------- datajunction-server/tests/api/access_test.py | 65 ++-- .../tests/api/dimensions_access_test.py | 40 ++- .../tests/api/namespaces_test.py | 102 +++--- datajunction-server/tests/api/sql_test.py | 92 +++--- datajunction-server/tests/conftest.py | 36 +-- .../tests/internal/authorization_test.py | 125 +++----- .../internal/deployment/orchestration_test.py | 1 - 29 files changed, 889 insertions(+), 1002 deletions(-) diff --git a/datajunction-server/datajunction_server/api/cubes.py b/datajunction-server/datajunction_server/api/cubes.py index c565c6c37..339b82f71 100644 --- a/datajunction-server/datajunction_server/api/cubes.py +++ b/datajunction-server/datajunction_server/api/cubes.py @@ -14,13 +14,15 @@ from datajunction_server.database.user import User from datajunction_server.errors import DJInvalidInputException from datajunction_server.internal.access.authentication.http import SecureAPIRouter -from datajunction_server.internal.access.authorization import validate_access +from datajunction_server.internal.access.authorization import ( + AccessChecker, + get_access_checker, +) from datajunction_server.internal.materializations import build_cube_materialization from datajunction_server.internal.nodes import ( get_all_cube_revisions_metadata, get_single_cube_revision_metadata, ) -from datajunction_server.models import access from datajunction_server.models.cube import ( CubeRevisionMetadata, DimensionValue, @@ -208,9 +210,7 @@ async def get_cube_dimension_sql( include_counts: bool = False, session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> TranslatedSQL: """ Generates SQL to retrieve all unique values of a dimension for the cube @@ -222,7 +222,7 @@ async def get_cube_dimension_sql( node_revision, dimensions, current_user, - validate_access, + access_checker, filters, limit, include_counts, @@ -251,9 +251,7 @@ async def get_cube_dimension_values( request: Request, query_service_client: QueryServiceClient = Depends(get_query_service_client), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> DimensionValues: """ All unique values of a dimension from the cube @@ -266,7 +264,7 @@ async def get_cube_dimension_values( cube, dimensions, current_user, - validate_access, + access_checker, filters, limit, include_counts, diff --git a/datajunction-server/datajunction_server/api/data.py b/datajunction-server/datajunction_server/api/data.py index 5baad69ee..089374ec3 100644 --- a/datajunction-server/datajunction_server/api/data.py +++ b/datajunction-server/datajunction_server/api/data.py @@ -29,9 +29,9 @@ ) from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( + AccessChecker, AccessDenialMode, - authorize, - validate_access, + get_access_checker, ) from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.models import access @@ -67,9 +67,7 @@ async def add_availability_state( *, session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), save_history: Callable = Depends(get_save_history), ) -> JSONResponse: """ @@ -91,17 +89,13 @@ async def add_availability_state( # Source nodes require that any availability states set are for one of the defined tables node_revision = node.current # type: ignore - await authorize( - session=session, - user=current_user, - resource_requests=[ - access.ResourceRequest( - verb=access.ResourceAction.WRITE, - access_object=access.Resource.from_node(node_revision), - ), - ], - on_denied=AccessDenialMode.RAISE, + access_checker.add_request( + access.ResourceRequest( + verb=access.ResourceAction.WRITE, + access_object=access.Resource.from_node(node_revision), + ), ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) if node.current.type == NodeType.SOURCE: # type: ignore if ( @@ -191,9 +185,7 @@ async def remove_availability_state( *, session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), save_history: Callable = Depends(get_save_history), ) -> JSONResponse: """ @@ -216,17 +208,13 @@ async def remove_availability_state( ), ) - await authorize( - session=session, - user=current_user, - resource_requests=[ - access.ResourceRequest( - verb=access.ResourceAction.WRITE, - access_object=access.Resource.from_node(node.current), # type: ignore - ), - ], - on_denied=AccessDenialMode.RAISE, + access_checker.add_request( + access.ResourceRequest( + verb=access.ResourceAction.WRITE, + access_object=access.Resource.from_node(node.current), # type: ignore + ), ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) # Save the old availability state for history record old_availability = ( @@ -282,10 +270,6 @@ async def get_data( query_service_client: QueryServiceClient = Depends(get_query_service_client), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), background_tasks: BackgroundTasks, cache: Cache = Depends(get_cache), ) -> QueryWithResults: @@ -311,8 +295,6 @@ async def get_data( engine_version=engine_version, use_materialized=use_materialized, ignore_errors=ignore_errors, - current_user=current_user, - validate_access=validate_access, ), ) @@ -363,10 +345,6 @@ async def get_data_stream_for_node( query_service_client: QueryServiceClient = Depends(get_query_service_client), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), background_tasks: BackgroundTasks, cache: Cache = Depends(get_cache), ) -> QueryWithResults: @@ -402,8 +380,6 @@ async def get_data_stream_for_node( engine_version=engine_version, use_materialized=True, ignore_errors=False, - current_user=current_user, - validate_access=validate_access, ), ) query_create = QueryCreate( @@ -469,10 +445,6 @@ async def get_data_for_metrics( query_service_client: QueryServiceClient = Depends(get_query_service_client), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), cache: Cache = Depends(get_cache), background_tasks: BackgroundTasks, ) -> QueryWithResults: @@ -498,8 +470,6 @@ async def get_data_for_metrics( engine_version=engine_version, use_materialized=True, ignore_errors=False, - current_user=current_user, - validate_access=validate_access, ), ) node = cast( @@ -545,20 +515,12 @@ async def get_data_stream_for_metrics( engine_name: Optional[str] = None, engine_version: Optional[str] = None, current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> QueryWithResults: """ Return data for a set of metrics with dimensions and filters using server sent events """ request_headers = dict(request.headers) - access_control = access.AccessControlStore( - validate_access=validate_access, - user=current_user, - base_verb=access.ResourceAction.READ, - ) - translated_sql, engine, catalog = await build_sql_for_multiple_metrics( session, metrics, @@ -568,7 +530,7 @@ async def get_data_stream_for_metrics( limit, engine_name, engine_version, - access_control, + access_checker, ) query_create = QueryCreate( diff --git a/datajunction-server/datajunction_server/api/deployments.py b/datajunction-server/datajunction_server/api/deployments.py index 4ec7cc945..03b05bb89 100644 --- a/datajunction-server/datajunction_server/api/deployments.py +++ b/datajunction-server/datajunction_server/api/deployments.py @@ -24,7 +24,9 @@ from datajunction_server.internal.deployment.utils import DeploymentContext from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( - validate_access, + AccessChecker, + AccessDenialMode, + get_access_checker, ) from datajunction_server.models import access from datajunction_server.models.deployment import DeploymentStatus @@ -158,22 +160,30 @@ async def create_deployment( current_user: User = Depends(get_current_user), query_service_client: QueryServiceClient = Depends(get_query_service_client), cache: Cache = Depends(get_cache), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> DeploymentInfo: """ This endpoint takes a deployment specification (namespace, nodes, tags), topologically sorts and validates the deployable objects, and deploys the nodes in parallel where possible. It returns a summary of the deployment. """ + access_checker.add_request( + access.ResourceRequest( + verb=access.ResourceAction.WRITE, + access_object=access.Resource( + resource_type=access.ResourceType.NAMESPACE, + name=deployment_spec.namespace, + ), + ), + ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + deployment_id = await executor.submit( spec=deployment_spec, context=DeploymentContext( current_user=current_user, request=request, query_service_client=query_service_client, - validate_access=validate_access, background_tasks=background_tasks, cache=cache, ), diff --git a/datajunction-server/datajunction_server/api/dimensions.py b/datajunction-server/datajunction_server/api/dimensions.py index 43e35f46c..0bb8a341e 100644 --- a/datajunction-server/datajunction_server/api/dimensions.py +++ b/datajunction-server/datajunction_server/api/dimensions.py @@ -11,12 +11,11 @@ from datajunction_server.models.node import NodeNameOutput from datajunction_server.api.helpers import get_node_by_name from datajunction_server.api.nodes import list_nodes -from datajunction_server.database.user import User +from datajunction_server.database.node import Node from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( - AccessDenialMode, - validate_access, - authorize, + AccessChecker, + get_access_checker, ) from datajunction_server.models import access from datajunction_server.models.node import NodeIndegreeOutput @@ -26,7 +25,6 @@ get_nodes_with_common_dimensions, ) from datajunction_server.utils import ( - get_current_user, get_session, get_settings, ) @@ -41,10 +39,7 @@ async def list_dimensions( prefix: Optional[str] = None, *, session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeIndegreeOutput]: """ List all available dimensions. @@ -53,8 +48,7 @@ async def list_dimensions( node_type=NodeType.DIMENSION, prefix=prefix, session=session, - current_user=current_user, - validate_access=validate_access, + access_checker=access_checker, ) node_indegrees = await get_dimension_dag_indegree(session, node_names) return sorted( @@ -72,49 +66,22 @@ async def find_nodes_with_dimension( *, node_type: List[NodeType] = Query([]), session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeNameOutput]: """ List all nodes that have the specified dimension """ - dimension_node = await get_node_by_name(session, name) - - # Ensure the user has access to the dimension node first - await authorize( - session=session, - user=current_user, - resource_requests=[ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource.from_node(dimension_node), - ) - ], - raise_on_denied=True, - ) + dimension_node = await Node.get_by_name(session, name) + access_checker.add_node(dimension_node, access.ResourceAction.READ) nodes = await get_nodes_with_common_dimensions( session, [dimension_node], node_type if node_type else None, ) - - # Only return nodes the user has access to - approvals = await authorize( - session=session, - user=current_user, - resource_requests=[ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource.from_node(node), - ) - for node in nodes - ], - on_denied=AccessDenialMode.FILTER, - ) - return [NodeNameOutput(name=node.name) for node in nodes if node.name in approvals] + access_checker.add_nodes(nodes, access.ResourceAction.READ) + approved_nodes = await access_checker.approved_resource_names() + return [node for node in nodes if node.name in approved_nodes] @router.get("/dimensions/common/", response_model=List[NodeNameOutput]) @@ -123,10 +90,7 @@ async def find_nodes_with_common_dimensions( node_type: List[NodeType] = Query([]), *, session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeNameOutput]: """ Find all nodes that have the list of common dimensions @@ -136,19 +100,6 @@ async def find_nodes_with_common_dimensions( [await get_node_by_name(session, dim) for dim in dimension], # type: ignore node_type, ) - resource_requests = await authorize( - session=session, - user=current_user, - resource_requests=[ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource.from_node(node), - ) - for node in nodes - ], - on_denied=AccessDenialMode.FILTER, - ) - approved_resource_names = [ - request.access_object.name for request in resource_requests - ] - return [NodeNameOutput(name=node.name) for node in nodes if node.name in approved_resource_names] + access_checker.add_nodes(nodes, access.ResourceAction.READ) + approved_resource_names = await access_checker.approved_resource_names() + return [node for node in nodes if node.name in approved_resource_names] diff --git a/datajunction-server/datajunction_server/api/djsql.py b/datajunction-server/datajunction_server/api/djsql.py index 6f01edb84..3de60edf2 100644 --- a/datajunction-server/datajunction_server/api/djsql.py +++ b/datajunction-server/datajunction_server/api/djsql.py @@ -9,14 +9,14 @@ from sse_starlette.sse import EventSourceResponse from datajunction_server.api.helpers import build_sql_for_dj_query, query_event_stream -from datajunction_server.database.user import User from datajunction_server.internal.access.authentication.http import SecureAPIRouter -from datajunction_server.internal.access.authorization import validate_access -from datajunction_server.models import access +from datajunction_server.internal.access.authorization import ( + AccessChecker, + get_access_checker, +) from datajunction_server.models.query import QueryCreate, QueryWithResults from datajunction_server.service_clients import QueryServiceClient from datajunction_server.utils import ( - get_current_user, get_query_service_client, get_session, get_settings, @@ -36,24 +36,16 @@ async def get_data_for_djsql( query_service_client: QueryServiceClient = Depends(get_query_service_client), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> QueryWithResults: """ Return data for a DJ SQL query """ request_headers = dict(request.headers) - access_control = access.AccessControlStore( - validate_access=validate_access, - user=current_user, - base_verb=access.ResourceAction.EXECUTE, - ) translated_sql, engine, catalog = await build_sql_for_dj_query( session, query, - access_control, + access_checker, engine_name, engine_version, ) @@ -86,24 +78,16 @@ async def get_data_stream_for_djsql( query_service_client: QueryServiceClient = Depends(get_query_service_client), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> QueryWithResults: # pragma: no cover """ Return data for a DJ SQL query using server side events """ request_headers = dict(request.headers) - access_control = access.AccessControlStore( - validate_access=validate_access, - user=current_user, - base_verb=access.ResourceAction.EXECUTE, - ) translated_sql, engine, catalog = await build_sql_for_dj_query( session, query, - access_control, + access_checker, engine_name, engine_version, ) diff --git a/datajunction-server/datajunction_server/api/helpers.py b/datajunction-server/datajunction_server/api/helpers.py index 5305d97e6..96d4ab41f 100644 --- a/datajunction-server/datajunction_server/api/helpers.py +++ b/datajunction-server/datajunction_server/api/helpers.py @@ -18,6 +18,10 @@ from sqlalchemy.orm import defer, joinedload, selectinload from sqlalchemy.sql.operators import and_, is_ +from datajunction_server.internal.access.authorization import ( + AccessChecker, + AccessDenialMode, +) from datajunction_server.api.notifications import get_notifier from datajunction_server.construction.build import ( get_default_criteria, @@ -207,7 +211,8 @@ async def get_query( orderby: List[str], limit: Optional[int] = None, engine: Optional[Engine] = None, - access_control: Optional[access.AccessControlStore] = None, + *, + access_checker: AccessChecker, use_materialized: bool = True, query_parameters: Optional[Dict[str, str]] = None, ignore_errors: bool = True, @@ -227,7 +232,7 @@ async def get_query( if ignore_errors: query_builder.ignore_errors() query_ast = await ( - query_builder.with_access_control(access_control) + query_builder.with_access_control(access_checker) .with_build_criteria(build_criteria) .add_dimensions(dimensions) .add_filters(filters) @@ -776,7 +781,7 @@ async def query_event_stream( async def build_sql_for_dj_query( # pragma: no cover session: AsyncSession, query: str, - access_control: access.AccessControl, + access_checker: AccessChecker, engine_name: Optional[str] = None, engine_version: Optional[str] = None, ) -> Tuple[TranslatedSQL, Engine, Catalog]: @@ -787,11 +792,12 @@ async def build_sql_for_dj_query( # pragma: no cover query_ast, dj_nodes = await build_dj_query(session, query) for node in dj_nodes: # pragma: no cover - access_control.add_request_by_node( # pragma: no cover + access_checker.add_node( # pragma: no cover node.current, + access.ResourceAction.READ, ) - access_control.validate_and_raise() # pragma: no cover + await access_checker.check(on_denied=AccessDenialMode.RAISE) # pragma: no cover leading_metric_node = dj_nodes[0] # pragma: no cover available_engines = ( # pragma: no cover diff --git a/datajunction-server/datajunction_server/api/materializations.py b/datajunction-server/datajunction_server/api/materializations.py index d2399729d..6bfbc0cc7 100644 --- a/datajunction-server/datajunction_server/api/materializations.py +++ b/datajunction-server/datajunction_server/api/materializations.py @@ -22,14 +22,16 @@ from datajunction_server.database.user import User from datajunction_server.errors import DJDoesNotExistException, DJInvalidInputException from datajunction_server.internal.access.authentication.http import SecureAPIRouter -from datajunction_server.internal.access.authorization import validate_access +from datajunction_server.internal.access.authorization import ( + AccessChecker, + get_access_checker, +) from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.internal.materializations import ( create_new_materialization, schedule_materialization_jobs, ) from datajunction_server.materialization.jobs import MaterializationJob -from datajunction_server.models import access from datajunction_server.models.base import labelize from datajunction_server.models.cube_materialization import UpsertCubeMaterialization from datajunction_server.models.node import AvailabilityStateInfo @@ -97,9 +99,7 @@ async def upsert_materialization( query_service_client: QueryServiceClient = Depends(get_query_service_client), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Add or update a materialization of the specified node. If a node_name is specified @@ -136,7 +136,7 @@ async def upsert_materialization( session, current_revision, materialization, - validate_access, # type: ignore + access_checker, # type: ignore current_user=current_user, ) diff --git a/datajunction-server/datajunction_server/api/metrics.py b/datajunction-server/datajunction_server/api/metrics.py index 417d4c344..206fb54b2 100644 --- a/datajunction-server/datajunction_server/api/metrics.py +++ b/datajunction-server/datajunction_server/api/metrics.py @@ -13,13 +13,14 @@ from datajunction_server.api.nodes import list_nodes from datajunction_server.database.node import Node, NodeRevision -from datajunction_server.database.user import User from datajunction_server.errors import DJError, DJInvalidInputException, ErrorCode from datajunction_server.internal.caching.cachelib_cache import get_cache from datajunction_server.internal.caching.interface import Cache from datajunction_server.internal.access.authentication.http import SecureAPIRouter -from datajunction_server.internal.access.authorization import validate_access -from datajunction_server.models import access +from datajunction_server.internal.access.authorization import ( + AccessChecker, + get_access_checker, +) from datajunction_server.models.metric import Metric from datajunction_server.models.node import ( DimensionAttributeOutput, @@ -30,7 +31,6 @@ from datajunction_server.models.node_type import NodeType from datajunction_server.sql.dag import get_dimensions, get_shared_dimensions from datajunction_server.utils import ( - get_current_user, get_session, get_settings, ) @@ -67,10 +67,7 @@ async def list_metrics( prefix: Optional[str] = None, *, session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), cache: Cache = Depends(get_cache), background_tasks: BackgroundTasks, ) -> List[str]: @@ -83,8 +80,7 @@ async def list_metrics( node_type=NodeType.METRIC, prefix=prefix, session=session, - current_user=current_user, - validate_access=validate_access, + access_checker=access_checker, ) background_tasks.add_task(cache.set, "metrics", metrics) return metrics diff --git a/datajunction-server/datajunction_server/api/namespaces.py b/datajunction-server/datajunction_server/api/namespaces.py index 69dfea142..6aeba17cf 100644 --- a/datajunction-server/datajunction_server/api/namespaces.py +++ b/datajunction-server/datajunction_server/api/namespaces.py @@ -19,8 +19,8 @@ from datajunction_server.models.dimensionlink import LinkType from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( - AccessDenialMode, - authorize, + AccessChecker, + get_access_checker, ) from datajunction_server.internal.namespaces import ( create_namespace, @@ -109,28 +109,17 @@ async def create_node_namespace( ) async def list_namespaces( session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NamespaceOutput]: """ List namespaces with the number of nodes contained in them """ results = await NodeNamespace.get_all_with_node_count(session) - resource_requests = [ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource.from_namespace(record.namespace), - ) - for record in results - ] - approved_namespaces = [ - request.access_object.name - for request in await authorize( - session=session, - user=current_user, - resource_requests=resource_requests, - on_denied=AccessDenialMode.FILTER, - ) - ] + access_checker.add_namespaces( + [record.namespace for record in results], + access.ResourceAction.READ, + ) + approved_namespaces = await access_checker.approved_resource_names() return [ NamespaceOutput(namespace=record.namespace, num_nodes=record.num_nodes) for record in results diff --git a/datajunction-server/datajunction_server/api/nodes.py b/datajunction-server/datajunction_server/api/nodes.py index 60c2add30..0d0effbf2 100644 --- a/datajunction-server/datajunction_server/api/nodes.py +++ b/datajunction-server/datajunction_server/api/nodes.py @@ -43,9 +43,8 @@ ) from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( - AccessDenialMode, - authorize, - validate_access, + AccessChecker, + get_access_checker, ) from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.internal.nodes import ( @@ -239,25 +238,14 @@ async def list_nodes( prefix: Optional[str] = None, *, session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[str]: """ List the available nodes. """ nodes = await Node.find(session, prefix, node_type) # type: ignore - approved_requests = await authorize( - session=session, - user=current_user, - resource_requests=[ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource.from_node(node), - ) - for node in nodes - ], - on_denied=AccessDenialMode.FILTER, - ) - return [req.access_object.name for req in approved_requests] + access_checker.add_nodes(nodes, access.ResourceAction.READ) + return await access_checker.approved_resource_names() @router.get("/nodes/details/", response_model=List[NodeIndexItem]) @@ -266,7 +254,7 @@ async def list_all_nodes_with_details( node_type: Optional[NodeType] = None, *, session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeIndexItem]: """ List the available nodes. @@ -296,24 +284,17 @@ async def list_all_nodes_with_details( "%s limit reached when returning all nodes, all nodes may not be captured in results", settings.node_list_max, ) - approvals = [ - request.access_object.name - for request in await authorize( - session=session, - user=current_user, - resource_requests=[ - access.ResourceRequest( - verb=access.ResourceAction.READ, - access_object=access.Resource( - name=row.name, - resource_type=access.ResourceType.NODE, - owner="", - ), - ) - for row in results - ], + for row in results: + access_checker.add_request( + access.ResourceRequest( + verb=access.ResourceAction.READ, + access_object=access.Resource( + name=row.name, + resource_type=access.ResourceType.NODE, + ), + ), ) - ] + approvals = await access_checker.approved_resource_names() return [row for row in results if row.name in approvals] @@ -441,9 +422,7 @@ async def create_source( current_user: User = Depends(get_current_user), request: Request, query_service_client: QueryServiceClient = Depends(get_query_service_client), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), background_tasks: BackgroundTasks, save_history: Callable = Depends(get_save_history), ) -> NodeOutput: @@ -457,7 +436,7 @@ async def create_source( session=session, current_user=current_user, query_service_client=query_service_client, - validate_access=validate_access, + access_checker=access_checker, background_tasks=background_tasks, save_history=save_history, ) @@ -489,9 +468,7 @@ async def create_node( current_user: User = Depends(get_current_user), query_service_client: QueryServiceClient = Depends(get_query_service_client), background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), save_history: Callable = Depends(get_save_history), cache: Cache = Depends(get_cache), ) -> NodeOutput: @@ -507,7 +484,7 @@ async def create_node( current_user=current_user, query_service_client=query_service_client, background_tasks=background_tasks, - validate_access=validate_access, + access_checker=access_checker, save_history=save_history, cache=cache, ) @@ -527,9 +504,7 @@ async def create_cube( query_service_client: QueryServiceClient = Depends(get_query_service_client), current_user: User = Depends(get_current_user), background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), save_history: Callable = Depends(get_save_history), ) -> NodeOutput: """ @@ -542,7 +517,7 @@ async def create_cube( current_user=current_user, query_service_client=query_service_client, background_tasks=background_tasks, - validate_access=validate_access, + access_checker=access_checker, save_history=save_history, ) @@ -1010,9 +985,7 @@ async def update_node( query_service_client: QueryServiceClient = Depends(get_query_service_client), current_user: User = Depends(get_current_user), background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), + access_checker: AccessChecker = Depends(get_access_checker), save_history: Callable = Depends(get_save_history), cache: Cache = Depends(get_cache), ) -> NodeOutput: @@ -1027,7 +1000,7 @@ async def update_node( query_service_client=query_service_client, current_user=current_user, background_tasks=background_tasks, - validate_access=validate_access, + access_checker=access_checker, request_headers=request_headers, save_history=save_history, refresh_materialization=refresh_materialization, diff --git a/datajunction-server/datajunction_server/api/sql.py b/datajunction-server/datajunction_server/api/sql.py index 12699c190..c16cc251a 100644 --- a/datajunction-server/datajunction_server/api/sql.py +++ b/datajunction-server/datajunction_server/api/sql.py @@ -26,11 +26,8 @@ from datajunction_server.internal.caching.interface import Cache from datajunction_server.database import Node from datajunction_server.database.queryrequest import QueryBuildType -from datajunction_server.database.user import User from datajunction_server.errors import DJInvalidInputException from datajunction_server.internal.access.authentication.http import SecureAPIRouter -from datajunction_server.internal.access.authorization import validate_access -from datajunction_server.models import access from datajunction_server.models.metric import TranslatedSQL, V3TranslatedSQL from datajunction_server.models.node_type import NodeType from datajunction_server.models.query import V3ColumnMetadata @@ -42,7 +39,6 @@ ) from datajunction_server.models.sql import GeneratedSQL from datajunction_server.utils import ( - get_current_user, get_session, get_settings, ) @@ -79,13 +75,8 @@ async def get_measures_sql_for_cube_v2( ), ), cache: Cache = Depends(get_cache), - session: AsyncSession = Depends(get_session), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: Optional[User] = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), use_materialized: bool = True, background_tasks: BackgroundTasks, request: Request, @@ -119,8 +110,6 @@ async def get_measures_sql_for_cube_v2( include_all_columns=include_all_columns, preaggregate=preaggregate, use_materialized=use_materialized, - current_user=current_user, - validate_access=validate_access, ), ) @@ -138,13 +127,8 @@ async def get_sql( limit: Optional[int] = None, query_params: str = Query("{}", description="Query parameters"), *, - session: AsyncSession = Depends(get_session), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), background_tasks: BackgroundTasks, ignore_errors: Optional[bool] = True, use_materialized: Optional[bool] = True, @@ -172,8 +156,6 @@ async def get_sql( engine_version=engine_version, use_materialized=use_materialized, ignore_errors=ignore_errors, - current_user=current_user, - validate_access=validate_access, ), ) @@ -392,6 +374,220 @@ async def get_metrics_sql_v3( ) +@router.get( + "/sql/measures/v3/", + response_model=MeasuresSQLResponse, + name="Get Measures SQL V3", + tags=["sql", "v3"], +) +async def get_measures_sql_v3( + metrics: List[str] = Query([]), + dimensions: List[str] = Query([]), + filters: List[str] = Query([]), + use_materialized: bool = Query(True), + *, + session: AsyncSession = Depends(get_session), + current_user: User = Depends(get_current_user), +) -> MeasuresSQLResponse: + """ + Generate pre-aggregated measures SQL for the requested metrics. + + Measures SQL represents the first stage of metric computation - it decomposes + each metric into its atomic aggregation components (e.g., SUM(amount), COUNT(*)) + and produces SQL that computes these components at the requested dimensional grain. + + Metrics are separated into grain groups, which represent sets of metrics that can be + computed together at a common grain. Each grain group produces its own SQL query, which + can be materialized independently to produce intermediate tables that are then queried + to compute final metric values. + + Returns: + One or more `GrainGroupSQL` objects, each containing: + - SQL query computing metric components at the specified grain + - Column metadata with semantic types + - Component details for downstream re-aggregation + + Args: + use_materialized: If True (default), use materialized tables when available. + Set to False when generating SQL for materialization refresh to avoid + circular references. + + See also: `/sql/metrics/v3/` for the final combined query with metric expressions. + """ + result = await build_measures_sql( + session=session, + metrics=metrics, + dimensions=dimensions, + filters=filters, + dialect=Dialect.SPARK, + use_materialized=use_materialized, + ) + + # Build a unified component_aliases map from all grain groups + # This maps component hash names -> actual SQL column aliases + all_component_aliases: dict[str, str] = {} + for gg in result.grain_groups: + all_component_aliases.update(gg.component_aliases) + + # Build metric formulas from decomposed metrics + metric_formulas = [] + for metric_name, decomposed in result.decomposed_metrics.items(): + # Get the combiner expression and rewrite component names to actual SQL aliases + from copy import deepcopy + + combiner_ast = deepcopy(decomposed.derived_ast.select.projection[0]) + + # Replace component hash names with actual SQL aliases in the combiner + for col in combiner_ast.find_all(ast.Column): + col_name = col.name.name if col.name else None + if col_name and col_name in all_component_aliases: + col.name = ast.Name(all_component_aliases[col_name]) + col._table = None + + combiner_str = str(combiner_ast) + + # Determine parent node name from the first grain group that contains this metric + parent_name = None + for gg in result.grain_groups: + if metric_name in gg.metrics: + parent_name = gg.parent_name + break + + # Check if this is a derived metric (references other metrics) + parent_names = result.ctx.parent_map.get(metric_name, []) + is_derived = decomposed.is_derived_for_parents( + parent_names, + result.ctx.nodes, + ) + + # Get component column names as they appear in SQL + # Use the unified component_aliases to resolve hash names -> actual aliases + component_names = [ + all_component_aliases.get(comp.name, comp.name) + for comp in decomposed.components + ] + + metric_formulas.append( + MetricFormulaResponse( + name=metric_name, + short_name=metric_name.split(".")[-1], + combiner=combiner_str, + components=component_names, + is_derived=is_derived, + parent_name=parent_name, + ), + ) + + return MeasuresSQLResponse( + grain_groups=[ + GrainGroupResponse( + sql=gg.sql, + columns=[ + V3ColumnMetadata( + name=col.name, + type=col.type, + semantic_entity=col.semantic_name, + semantic_type=col.semantic_type, + ) + for col in gg.columns + ], + grain=gg.grain, + aggregability=gg.aggregability.value + if hasattr(gg.aggregability, "value") + else str(gg.aggregability), + metrics=gg.metrics, + components=[ + ComponentResponse( + # Use actual SQL alias (metric short name for single-component, hash for multi) + name=gg.component_aliases.get(comp.name, comp.name), + expression=comp.expression, + aggregation=comp.aggregation, + merge=comp.merge, + aggregability=comp.rule.type.value + if hasattr(comp.rule.type, "value") + else str(comp.rule.type), + ) + for comp in gg.components + ], + parent_name=gg.parent_name, + ) + for gg in result.grain_groups + ], + metric_formulas=metric_formulas, + dialect=str(result.dialect) if result.dialect else None, + requested_dimensions=result.requested_dimensions, + ) + + +@router.get( + "/sql/metrics/v3/", + response_model=V3TranslatedSQL, + name="Get Metrics SQL V3", + tags=["sql", "v3"], +) +async def get_metrics_sql_v3( + metrics: List[str] = Query([]), + dimensions: List[str] = Query([]), + filters: List[str] = Query([]), + *, + session: AsyncSession = Depends(get_session), + current_user: User = Depends(get_current_user), +) -> V3TranslatedSQL: + """ + Generate final metrics SQL with fully computed metric expressions. + + Metrics SQL is the second (and final) stage of metric computation - it takes + the pre-aggregated components from Measures SQL and applies combiner expressions + to produce the actual metric values requested. + + - Metric components are re-aggregated as needed to match the requested + dimensional grain. + + - Derived metrics (defined as expressions over other metrics) + (e.g., `conversion_rate = order_count / visitor_count`) are computed by + substituting component references with their re-aggregated expressions. + + - When metrics come from different fact tables, their + grain groups are FULL OUTER JOINed on the common dimensions, with COALESCE + for dimension columns to handle NULLs from non-matching rows. + + - Dimension references in metric expressions are resolved to their + final column aliases. + + Returns: + A single SQL query that: + - Defines CTEs for each grain group (pre-aggregated component data) or + uses materialized pre-agg tables when available + - Joins grain groups on shared dimensions (if multiple) + - Builds dimensions with coalesce and metrics with combiner expressions + - Groups by dimensions to finalize re-aggregation + + See also: `/sql/measures/v3/` for the underlying pre-aggregated components. + """ + + result = await build_metrics_sql( + session=session, + metrics=metrics, + dimensions=dimensions, + filters=filters, + dialect=Dialect.SPARK, + ) + + return V3TranslatedSQL( + sql=result.sql, + columns=[ + V3ColumnMetadata( + name=col.name, + type=col.type, + semantic_entity=col.semantic_name, + semantic_type=col.semantic_type, + ) + for col in result.columns + ], + dialect=result.dialect, + ) + + @router.get("/sql/", response_model=TranslatedSQL, name="Get SQL For Metrics") async def get_sql_for_metrics( metrics: List[str] = Query([]), @@ -404,10 +600,6 @@ async def get_sql_for_metrics( session: AsyncSession = Depends(get_session), engine_name: Optional[str] = None, engine_version: Optional[str] = None, - current_user: User = Depends(get_current_user), - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), ignore_errors: Optional[bool] = True, use_materialized: Optional[bool] = True, background_tasks: BackgroundTasks, @@ -449,7 +641,5 @@ async def get_sql_for_metrics( engine_version=engine_version, use_materialized=use_materialized, ignore_errors=ignore_errors, - current_user=current_user, - validate_access=validate_access, ), ) diff --git a/datajunction-server/datajunction_server/api/system.py b/datajunction-server/datajunction_server/api/system.py index c8c3407ea..835cb1d85 100644 --- a/datajunction-server/datajunction_server/api/system.py +++ b/datajunction-server/datajunction_server/api/system.py @@ -7,10 +7,6 @@ from sqlalchemy import select, text from sqlalchemy.ext.asyncio import AsyncSession -from datajunction_server.internal.access.authorization import ( - validate_access, -) -from datajunction_server.models import access from datajunction_server.models.system import DimensionStats, RowOutput from datajunction_server.sql.dag import ( get_cubes_using_dimensions, @@ -18,12 +14,10 @@ ) from datajunction_server.internal.caching.cachelib_cache import get_cache from datajunction_server.internal.caching.interface import Cache -from datajunction_server.database.user import User from datajunction_server.database.node import Node from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.models.node_type import NodeType from datajunction_server.utils import ( - get_current_user, get_session, get_settings, ) @@ -63,11 +57,7 @@ async def get_data_for_system_metric( limit: int | None = None, session: AsyncSession = Depends(get_session), *, - current_user: User = Depends(get_current_user), background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn = Depends( - validate_access, - ), cache: Cache = Depends(get_cache), request: Request, ) -> list[list[RowOutput]]: @@ -94,8 +84,6 @@ async def get_data_for_system_metric( filters=filters, orderby=orderby, limit=limit, - current_user=current_user, - validate_access=validate_access, ), ) results = await session.execute(text(translated_sql.sql)) diff --git a/datajunction-server/datajunction_server/construction/build.py b/datajunction-server/datajunction_server/construction/build.py index 7588f61a0..6344214f5 100755 --- a/datajunction-server/datajunction_server/construction/build.py +++ b/datajunction-server/datajunction_server/construction/build.py @@ -12,7 +12,7 @@ from datajunction_server.database.node import Node, NodeRevision from datajunction_server.errors import DJError, DJInvalidInputException, ErrorCode from datajunction_server.internal.engines import get_engine -from datajunction_server.models import access +from datajunction_server.internal.access.authorization import AccessChecker from datajunction_server.models.cube_materialization import MetricComponent from datajunction_server.models.engine import Dialect from datajunction_server.models.materialization import GenericCubeConfig @@ -176,7 +176,7 @@ async def build_metric_nodes( engine_name: Optional[str] = None, engine_version: Optional[str] = None, build_criteria: Optional[BuildCriteria] = None, - access_control: Optional[access.AccessControlStore] = None, + access_checker: AccessChecker | None = None, ignore_errors: bool = True, query_parameters: Optional[dict[str, Any]] = None, ): @@ -214,7 +214,7 @@ async def build_metric_nodes( .order_by(orderby) .limit(limit) .with_build_criteria(build_criteria) - .with_access_control(access_control) + .with_access_control(access_checker) ) if ignore_errors: builder = builder.ignore_errors() diff --git a/datajunction-server/datajunction_server/construction/build_v2.py b/datajunction-server/datajunction_server/construction/build_v2.py index 7ab01224c..0fc8503b8 100644 --- a/datajunction-server/datajunction_server/construction/build_v2.py +++ b/datajunction-server/datajunction_server/construction/build_v2.py @@ -18,6 +18,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, selectinload +from datajunction_server.internal.access.authorization import ( + AccessChecker, + AccessDenialMode, +) from datajunction_server.construction.utils import to_namespaced_name from datajunction_server.database import Engine from datajunction_server.database.attributetype import ColumnAttribute @@ -274,7 +278,7 @@ def __init__( self._orderby: list[str] = [] self._limit: Optional[int] = None self._build_criteria: Optional[BuildCriteria] = self.get_default_criteria() - self._access_control: Optional[access.AccessControlStore] = None + self._access_checker: Optional[AccessChecker] = None self._ignore_errors: bool = False # The following attributes will be modified as the query gets built. @@ -384,14 +388,14 @@ def with_build_criteria(self, build_criteria: Optional[BuildCriteria] = None): def with_access_control( self, - access_control: Optional[access.AccessControlStore] = None, + access_checker: AccessChecker, ): """ Set access control for the query builder. """ - if access_control: # pragma: no cover - access_control.add_request_by_node(self.node_revision) - self._access_control = access_control + if access_checker: # pragma: no cover + access_checker.add_node(self.node_revision, access.ResourceAction.READ) + self._access_checker = access_checker return self @property @@ -572,7 +576,7 @@ async def build(self) -> ast.Query: ) # Error validation - self.validate_access() + await self.validate_access() if self.errors and not self._ignore_errors: raise DJQueryBuildException(errors=self.errors) return self.final_ast # type: ignore @@ -991,23 +995,18 @@ def set_dimension_aliases(self): node_col.set_semantic_entity(dim_name) node_col.set_semantic_type(SemanticType.DIMENSION) - async def add_request_by_node_name(self, node_name): + def add_request_by_node_name(self, node_name: str): """Add a node request to the access control validator.""" - if self._access_control: # pragma: no cover - # Use cached node if available to avoid DB lookup - cached_node = self.dependencies_cache.get(node_name) - if cached_node: # pragma: no cover - self._access_control.add_request_by_node(cached_node) - else: # pragma: no cover - await self._access_control.add_request_by_node_name( - self.session, - node_name, - ) + if self._access_checker: # pragma: no cover + self._access_checker.add_request_by_node_name( + node_name, + access.ResourceAction.READ, + ) - def validate_access(self): + async def validate_access(self): """Validates access""" - if self._access_control: - self._access_control.validate_and_raise() + if self._access_checker: + await self._access_checker.check(on_denied=AccessDenialMode.RAISE) async def find_dimension_node_joins( self, @@ -1059,7 +1058,7 @@ async def find_dimension_node_joins( # Build DimensionJoin objects using preloaded paths for attr in non_local_dimensions: - await self.add_request_by_node_name(attr.node_name) + self.add_request_by_node_name(attr.node_name) if attr.join_key not in dimension_node_joins: # Find matching path - try exact role match first, then no-role match @@ -1177,8 +1176,8 @@ def __init__( self._orderby: list[str] = [] self._limit: Optional[int] = None self._parameters: dict[str, ast.Value] = {} - self._build_criteria: Optional[BuildCriteria] = self.get_default_criteria() - self._access_control: Optional[access.AccessControlStore] = None + self._build_criteria: BuildCriteria | None = self.get_default_criteria() + self._access_checker: AccessChecker | None = None self._ignore_errors: bool = False # The following attributes will be modified as the query gets built. @@ -1293,14 +1292,13 @@ def with_build_criteria(self, build_criteria: Optional[BuildCriteria] = None): def with_access_control( self, - access_control: Optional[access.AccessControlStore] = None, + access_checker: AccessChecker, ): """ Set access control for the query builder. """ - if access_control: # pragma: no cover - access_control.add_request_by_nodes(self.metric_nodes) - self._access_control = access_control + access_checker.add_nodes(self.metric_nodes, access.ResourceAction.READ) + self._access_checker = access_checker return self @property @@ -1398,15 +1396,15 @@ async def build(self) -> ast.Query: self.final_ast.select.limit = ast.Number(value=self._limit) # Error validation - self.validate_access() + await self.validate_access() if self.errors and not self._ignore_errors: raise DJQueryBuildException(errors=self.errors) # pragma: no cover return self.final_ast - def validate_access(self): + async def validate_access(self): """Validates access""" - if self._access_control: - self._access_control.validate_and_raise() + if self._access_checker: + await self._access_checker.check(on_denied=AccessDenialMode.RAISE) async def build_measures_queries(self): """ @@ -1424,7 +1422,7 @@ async def build_measures_queries(self): if self._ignore_errors: query_builder = query_builder.ignore_errors() parent_ast = await ( - query_builder.with_access_control(self._access_control) + query_builder.with_access_control(self._access_checker) .with_build_criteria(self._build_criteria) .add_dimensions(self.dimensions) .add_filters(self.filters) @@ -1483,18 +1481,20 @@ async def build_metric_agg( """ Build the metric's aggregate expression. """ - if self._access_control: - self._access_control.add_request_by_node(metric_node) # type: ignore + if self._access_checker: + self._access_checker.add_node(metric_node, access.ResourceAction.READ) # type: ignore metric_query_builder = await QueryBuilder.create(self.session, metric_node) if self._ignore_errors: metric_query_builder = ( # pragma: no cover metric_query_builder.ignore_errors() ) - metric_query = await ( - metric_query_builder.with_access_control(self._access_control) - .with_build_criteria(self._build_criteria) - .build() - ) + if self._access_checker: + metric_query_builder = metric_query_builder.with_access_control( + self._access_checker, + ) + metric_query = await metric_query_builder.with_build_criteria( + self._build_criteria, + ).build() self.errors.extend(metric_query_builder.errors) metric_query.ctes[-1].select.projection[0].set_semantic_entity( # type: ignore f"{metric_node.name}.{amenable_name(metric_node.name)}", diff --git a/datajunction-server/datajunction_server/construction/dimensions.py b/datajunction-server/datajunction_server/construction/dimensions.py index dceb343fc..463c1db38 100644 --- a/datajunction-server/datajunction_server/construction/dimensions.py +++ b/datajunction-server/datajunction_server/construction/dimensions.py @@ -11,7 +11,7 @@ from datajunction_server.database.node import NodeRevision from datajunction_server.database.user import User from datajunction_server.errors import DJInvalidInputException -from datajunction_server.models import access +from datajunction_server.internal.access.authorization import AccessChecker from datajunction_server.models.column import SemanticType from datajunction_server.models.metric import TranslatedSQL from datajunction_server.models.query import ColumnMetadata @@ -27,7 +27,7 @@ async def build_dimensions_from_cube_query( cube: NodeRevision, dimensions: List[str], current_user: User, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, filters: Optional[str] = None, limit: Optional[int] = 50000, include_counts: bool = False, @@ -101,8 +101,7 @@ async def build_dimensions_from_cube_query( metrics=[metric.name for metric in cube.cube_metrics()], dimensions=dimensions, filters=[filters] if filters else [], - current_user=current_user, - validate_access=validate_access, + access_checker=access_checker, ) measures_query_ast = parse(measures_query[0].sql) measures_query_ast.bake_ctes() diff --git a/datajunction-server/datajunction_server/internal/access/authorization.py b/datajunction-server/datajunction_server/internal/access/authorization.py index ca06acbd7..e51d188d4 100644 --- a/datajunction-server/datajunction_server/internal/access/authorization.py +++ b/datajunction-server/datajunction_server/internal/access/authorization.py @@ -1,13 +1,10 @@ """ Authorization related functionality using pluggable services. -This module provides a pluggable authorization system that works with pre-loaded -user data (no async DB queries needed during authorization): - -- User's roles/assignments are eagerly loaded when fetching the user -- AuthorizationService performs sync in-memory checks -- No changes needed to existing API endpoints -- Keeps the existing validate_access() pattern +This module defines an abstract base class `AuthorizationService` for implementing different +authorization strategies. It includes built-in implementations such as +`RBACAuthorizationService` for role-based access control and `PassthroughAuthorizationService` +for permissive access. Example custom implementation: ```python @@ -20,6 +17,7 @@ def authorize(self, auth_context, requests): ``` """ +from fastapi import Depends from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timezone @@ -35,16 +33,22 @@ def authorize(self, auth_context, requests): GroupMembershipService, get_group_membership_service, ) +from datajunction_server.database.node import Node from datajunction_server.database.rbac import RoleAssignment from datajunction_server.database.user import User from datajunction_server.models.access import ( - AccessControl, + AccessDecision, + Resource, ResourceAction, ResourceRequest, ResourceType, - ValidateAccessFn, ) -from datajunction_server.utils import SEPARATOR, get_settings +from datajunction_server.utils import ( + SEPARATOR, + get_current_user, + get_session, + get_settings, +) settings = get_settings() @@ -59,11 +63,9 @@ class AccessDenialMode(Enum): How to handle denied access requests. """ - FILTER = "filter" # Return only approved requests (default for list operations) - RAISE = "raise" # Raise exception if any denied (for single resource operations) - RETURN = ( - "return" # Return all requests with approved field set (for custom handling) - ) + FILTER = "filter" # Return only approved requests + RAISE = "raise" # Raise exception if any denied + RETURN = "return" # Return all requests with approved field set # ============================================================================ @@ -184,92 +186,6 @@ async def get_effective_assignments( return assignments -async def authorize( - session: AsyncSession, - user: User, - resource_requests: List[ResourceRequest], - *, - on_denied: AccessDenialMode = AccessDenialMode.FILTER, -) -> List[ResourceRequest]: - """ - Check access to resources with flexible denial handling. - - Args: - session: Database session - user: User requesting access - resource_requests: Resources to check access for - on_denied: How to handle denied requests: - - FILTER (default): Return only approved requests (for list operations) - - RAISE: Raise DJAuthorizationException if any denied (for single resource) - - RETURN: Return with approved field set (for custom handling) - - Returns: - List of resource requests (filtered or all, depending on on_denied mode) - - Raises: - DJAuthorizationException: If on_denied=RAISE and any requests are denied - """ - auth_context = await AuthContext.from_user(session, user) - auth_service = get_authorization_service() - all_requests = auth_service.authorize(auth_context, resource_requests) - - # Handle based on mode - if on_denied == AccessDenialMode.RETURN: - return all_requests - elif on_denied == AccessDenialMode.RAISE: - denied = [r for r in all_requests if not r.approved] - if denied: - from datajunction_server.errors import ( - DJAuthorizationException, - DJError, - ErrorCode, - ) - - raise DJAuthorizationException( - message=f"Access denied to {len(denied)} resource(s)", - errors=[ - DJError( - code=ErrorCode.UNAUTHORIZED_ACCESS, - message=( - f"{r.verb.value.upper()} access denied to " - f"{r.access_object.resource_type.value}: " - f"{r.access_object.name}" - ), - ) - for r in denied - ], - ) - return all_requests - # Default: FILTER - return [r for r in all_requests if r.approved] - - -def validate_access() -> ValidateAccessFn: - """ - Default validate access function that uses the configured authorization service. - - This delegates to the pluggable service (RBAC, passthrough, custom, etc.) - using the AuthContext attached to the AccessControl object. - """ - auth_service = get_authorization_service() - - def _(access_control: AccessControl): - """ - Authorizes requests using the configured service. - """ - auth_context = getattr(access_control, "auth_context", None) - if not auth_context: - # No auth context - approve all (backward compat) - access_control.approve_all() - return - - # Use authorization service - requests_list = list(access_control.requests) - auth_service.authorize(auth_context, requests_list) - - return _ - - # ============================================================================ # New FastAPI-style Authorization Service # ============================================================================ @@ -295,8 +211,8 @@ class AuthorizationService(ABC): def authorize( self, auth_context: AuthContext, - requests: List[ResourceRequest], - ) -> List[ResourceRequest]: + requests: list[ResourceRequest], + ) -> list[AccessDecision]: """ Authorize resource requests for a user. @@ -359,8 +275,8 @@ class RBACAuthorizationService(AuthorizationService): def authorize( self, auth_context: AuthContext, - requests: List[ResourceRequest], - ) -> List[ResourceRequest]: + requests: list[ResourceRequest], + ) -> list[AccessDecision]: """ Authorize using pre-loaded RBAC roles and scopes (sync). @@ -371,17 +287,26 @@ def authorize( Returns: Same list of requests with approved=True/False set """ - for request in requests: - has_grant = self.has_permission( - assignments=auth_context.role_assignments, - action=request.verb, - resource_type=request.access_object.resource_type, - resource_name=request.access_object.name, - ) - request.approved = ( - has_grant or settings.default_access_policy == "permissive" - ) - return requests + return [self._make_decision(auth_context, request) for request in requests] + + def _make_decision( + self, + auth_context: AuthContext, + request: ResourceRequest, + ) -> AccessDecision: + """ + Convert ResourceRequest to AccessDecision. + """ + has_grant = self.has_permission( + assignments=auth_context.role_assignments, + action=request.verb, + resource_type=request.access_object.resource_type, + resource_name=request.access_object.name, + ) + return AccessDecision( + request=request, + approved=(has_grant or settings.default_access_policy == "permissive"), + ) @classmethod def resource_matches_pattern(cls, resource_name: str, pattern: str) -> bool: @@ -516,12 +441,10 @@ class PassthroughAuthorizationService(AuthorizationService): def authorize( self, auth_context: AuthContext, - requests: List[ResourceRequest], - ) -> List[ResourceRequest]: + requests: list[ResourceRequest], + ) -> list[AccessDecision]: """Approve all requests without checks (sync).""" - for request in requests: - request.approved = True - return requests + return [AccessDecision(request=request, approved=True) for request in requests] @lru_cache(maxsize=None) @@ -577,3 +500,137 @@ def authorize(self, user, requests): f"Unknown authorization_provider: '{provider}'. " f"Available providers: {available}", ) + + +async def get_auth_context( + session: AsyncSession = Depends(get_session), + current_user: User = Depends(get_current_user), +) -> AuthContext: + """Build authorization context with user + group assignments.""" + return await AuthContext.from_user(session, current_user) + + +class AccessChecker: + """Collects authorization requests and validates them.""" + + def __init__(self, auth_context: AuthContext): + self.auth_context = auth_context + self.requests: list[ResourceRequest] = [] + + def add_request(self, request: ResourceRequest): + """Add a request to check.""" + self.requests.append(request) + + def add_requests(self, requests: list[ResourceRequest]): + """Add requests to check.""" + self.requests.extend(requests) + + @classmethod + def resource_request_from_node( + cls, + node: Node, + action: ResourceAction, + ) -> ResourceRequest: + """Create ResourceRequest from a Node.""" + return ResourceRequest( + verb=action, + access_object=Resource.from_node(node), + ) + + def add_request_by_node_name(self, node_name: str, action: ResourceAction): + """Add request by node name.""" + self.requests.append( + ResourceRequest( + verb=action, + access_object=Resource(name=node_name, resource_type=ResourceType.NODE), + ), + ) + + def add_node(self, node: Node, action: ResourceAction): + """Add request for a node.""" + node_request = self.resource_request_from_node(node, action) + self.add_request(node_request) + + def add_nodes(self, nodes: list[Node], action: ResourceAction): + """Add requests for multiple nodes.""" + self.requests.extend( + self.resource_request_from_node(node, action) for node in nodes + ) + + @classmethod + def resource_request_from_namespace( + cls, + namespace: str, + action: ResourceAction, + ) -> ResourceRequest: + """Create ResourceRequest from a namespace.""" + return ResourceRequest( + verb=action, + access_object=Resource.from_namespace(namespace), + ) + + def add_namespace(self, namespace: str, action: ResourceAction): + """Add request for a namespace.""" + namespace_request = self.resource_request_from_namespace(namespace, action) + self.add_request(namespace_request) + + def add_namespaces(self, namespaces: list[str], action: ResourceAction): + """Add requests for multiple namespaces.""" + self.requests.extend( + self.resource_request_from_namespace(namespace, action) + for namespace in namespaces + ) + + async def check( + self, + on_denied: AccessDenialMode = AccessDenialMode.FILTER, + ) -> list[AccessDecision]: + """ + Validate all requests using AuthorizationService. + + Args: + on_denied: How to handle denied requests + - FILTER: Return only approved (default) + - RAISE: Raise exception if any denied + - RETURN_ALL: Return all with approved field set + """ + auth_service = get_authorization_service() + access_decisions = auth_service.authorize(self.auth_context, self.requests) + + if on_denied == AccessDenialMode.RETURN: + return access_decisions + elif on_denied == AccessDenialMode.RAISE: + denied: list[AccessDecision] = [ + decision for decision in access_decisions if not decision.approved + ] + if denied: + from datajunction_server.errors import DJAuthorizationException + + # Show first 5 denied resources + denied_names = [d.request.access_object.name for d in denied[:5]] + more_count = max(0, len(denied) - 5) + + raise DJAuthorizationException( + message=( + f"Access denied to {len(denied)} resource(s): " + f"{', '.join(denied_names)}" + + (f" and {more_count} more" if more_count else "") + ), + ) + return access_decisions + # Default: FILTER + return [decision for decision in access_decisions if decision.approved] + + async def approved_resource_names(self) -> list[str]: + """Get approved resource names.""" + return [ + decision.request.access_object.name + for decision in await self.check(on_denied=AccessDenialMode.FILTER) + ] + + +def get_access_checker( + auth_context: AuthContext = Depends(get_auth_context), +) -> AccessChecker: + """Provide AccessChecker with pre-loaded context.""" + return AccessChecker(auth_context) diff --git a/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py b/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py index 8e431d774..d63b1e57e 100644 --- a/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py +++ b/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py @@ -12,11 +12,14 @@ QueryBuildType, VersionedQueryKey, ) +from datajunction_server.internal.access.authorization import ( + AccessChecker, + AuthContext, + get_access_checker, +) from datajunction_server.internal.sql import build_sql_for_multiple_metrics -from datajunction_server.database.user import User -from datajunction_server.models import access from datajunction_server.models.sql import GeneratedSQL -from datajunction_server.utils import session_context, get_settings +from datajunction_server.utils import get_current_user, session_context, get_settings from datajunction_server.internal.sql import get_measures_query from datajunction_server.internal.sql import build_node_sql from datajunction_server.internal.engines import get_engine @@ -40,8 +43,6 @@ class QueryRequestParams: limit: int | None = None orderby: list[str] | None = None other_args: dict[str, Any] | None = None - current_user: User | None = None - validate_access: access.ValidateAccessFn | None = None include_all_columns: bool = False use_materialized: bool = False preaggregate: bool = False @@ -58,6 +59,17 @@ def __repr__(self): ) +async def build_access_checker_from_request( + request: Request, + session: AsyncSession, +) -> AccessChecker: + """Helper to build checker from request + session.""" + print("Building access checker from request") + current_user = await get_current_user(request) + auth_context = await AuthContext.from_user(session, current_user) + return get_access_checker(auth_context) + + class QueryCacheManager(RefreshAheadCacheManager): """ A generic manager for handling caching operations. @@ -85,39 +97,32 @@ async def fallback( """ params = deepcopy(params) async with session_context(request) as session: + access_checker = await build_access_checker_from_request(request, session) params.nodes = list(OrderedDict.fromkeys(params.nodes)) query_parameters = ( json.loads(params.query_params) if params.query_params else {} ) - access_control_store = ( - access.AccessControlStore( - validate_access=params.validate_access, - user=params.current_user, - base_verb=access.ResourceAction.READ, - ) - if params.validate_access - else None - ) match self.query_type: case QueryBuildType.MEASURES: return await self._build_measures_query( session, params, query_parameters, + access_checker, ) case QueryBuildType.NODE: return await self._build_node_query( session, params, query_parameters, - access_control_store, + access_checker, ) case QueryBuildType.METRICS: # pragma: no cover return await self._build_metrics_query( session, params, query_parameters, - access_control_store, + access_checker, ) async def build_cache_key( @@ -155,6 +160,7 @@ async def _build_measures_query( session: AsyncSession, params: QueryRequestParams, query_parameters: dict[str, Any], + access_checker: AccessChecker, ) -> list[GeneratedSQL]: return await get_measures_query( session=session, @@ -164,8 +170,7 @@ async def _build_measures_query( orderby=params.orderby or [], engine_name=params.engine_name, engine_version=params.engine_version, - current_user=params.current_user, - validate_access=params.validate_access, + access_checker=access_checker, include_all_columns=params.include_all_columns, use_materialized=params.use_materialized, preagg_requested=params.preaggregate, @@ -177,7 +182,7 @@ async def _build_node_query( session: AsyncSession, params: QueryRequestParams, query_parameters: dict[str, Any], - access_control_store: access.AccessControlStore | None = None, + access_checker: AccessChecker, ) -> TranslatedSQL: engine = ( await get_engine(session, params.engine_name, params.engine_version) # type: ignore @@ -195,7 +200,7 @@ async def _build_node_query( ignore_errors=params.ignore_errors, use_materialized=params.use_materialized, query_parameters=query_parameters, - access_control=access_control_store, + access_checker=access_checker, ) return TranslatedSQL.create( sql=built_sql.sql, @@ -208,7 +213,7 @@ async def _build_metrics_query( session: AsyncSession, params: QueryRequestParams, query_parameters: dict[str, Any], - access_control_store: access.AccessControlStore | None = None, + access_checker: AccessChecker, ) -> TranslatedSQL: built_sql, _, _ = await build_sql_for_multiple_metrics( session=session, @@ -219,7 +224,7 @@ async def _build_metrics_query( limit=params.limit, engine_name=params.engine_name, engine_version=params.engine_version, - access_control=access_control_store, + access_checker=access_checker, ignore_errors=params.ignore_errors, # type: ignore query_parameters=query_parameters, use_materialized=params.use_materialized, # type: ignore diff --git a/datajunction-server/datajunction_server/internal/deployment/utils.py b/datajunction-server/datajunction_server/internal/deployment/utils.py index 7e77c4bf7..c9aeac4ac 100644 --- a/datajunction-server/datajunction_server/internal/deployment/utils.py +++ b/datajunction-server/datajunction_server/internal/deployment/utils.py @@ -6,7 +6,6 @@ from datajunction_server.internal.caching.interface import Cache from datajunction_server.service_clients import QueryServiceClient from datajunction_server.database.user import User -from datajunction_server.models import access from datajunction_server.models.deployment import ( NodeSpec, CubeSpec, @@ -114,6 +113,5 @@ class DeploymentContext: current_user: User request: Request query_service_client: QueryServiceClient - validate_access: access.ValidateAccessFn background_tasks: BackgroundTasks cache: Cache diff --git a/datajunction-server/datajunction_server/internal/materializations.py b/datajunction-server/datajunction_server/internal/materializations.py index b6705da5b..48d25b247 100644 --- a/datajunction-server/datajunction_server/internal/materializations.py +++ b/datajunction-server/datajunction_server/internal/materializations.py @@ -20,7 +20,7 @@ build_cube_materialization, ) from datajunction_server.materialization.jobs import MaterializationJob -from datajunction_server.models import access +from datajunction_server.internal.access.authorization import AccessChecker from datajunction_server.models.column import SemanticType from datajunction_server.models.cube_materialization import UpsertCubeMaterialization from datajunction_server.models.materialization import ( @@ -104,7 +104,7 @@ async def build_cube_materialization_config( session: AsyncSession, current_revision: NodeRevision, upsert_input: UpsertMaterialization, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, current_user: User, ) -> DruidMeasuresCubeConfig: """ @@ -129,6 +129,7 @@ async def build_cube_materialization_config( metrics=[node.name for node in current_revision.cube_metrics()], dimensions=current_revision.cube_dimensions(), use_materialized=False, + access_checker=access_checker, ) generic_config = DruidMetricsCubeConfig( lookback_window=upsert_input.config.lookback_window, @@ -156,8 +157,7 @@ async def build_cube_materialization_config( metrics=[node.name for node in current_revision.cube_metrics()], dimensions=current_revision.cube_dimensions(), filters=[], - current_user=current_user, - validate_access=validate_access, + access_checker=access_checker, ) for measures_query in measures_queries: metrics_expressions = await rewrite_metrics_expressions( @@ -238,7 +238,7 @@ async def create_new_materialization( session: AsyncSession, current_revision: NodeRevision, upsert: UpsertCubeMaterialization | UpsertMaterialization, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, current_user: User, ) -> Materialization: """ @@ -284,7 +284,7 @@ async def create_new_materialization( session, current_revision, upsert, - validate_access, + access_checker, current_user=current_user, ) materialization_name = ( diff --git a/datajunction-server/datajunction_server/internal/nodes.py b/datajunction-server/datajunction_server/internal/nodes.py index e14eac31c..629522223 100644 --- a/datajunction-server/datajunction_server/internal/nodes.py +++ b/datajunction-server/datajunction_server/internal/nodes.py @@ -12,6 +12,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, selectinload + +from datajunction_server.internal.access.authorization import ( + AccessChecker, + AccessDenialMode, +) from datajunction_server.internal.caching.interface import Cache from datajunction_server.models.query import QueryCreate from datajunction_server.api.helpers import ( @@ -118,7 +123,7 @@ async def create_a_source_node( current_user: User, query_service_client: QueryServiceClient, background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, save_history: Callable, ): request_headers = dict(request.headers) @@ -132,7 +137,7 @@ async def create_a_source_node( current_user=current_user, request_headers=request_headers, query_service_client=query_service_client, - validate_access=validate_access, + access_checker=access_checker, background_tasks=background_tasks, save_history=save_history, ): @@ -212,7 +217,7 @@ async def create_a_node( current_user: User, query_service_client: QueryServiceClient, background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, save_history: Callable, cache: Cache, ) -> Node: @@ -231,7 +236,7 @@ async def create_a_node( request_headers=request_headers, query_service_client=query_service_client, background_tasks=background_tasks, - validate_access=validate_access, + access_checker=access_checker, save_history=save_history, cache=cache, ): @@ -303,7 +308,7 @@ async def create_a_cube( current_user: User, query_service_client: QueryServiceClient, background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, save_history: Callable, ) -> Node: request_headers = dict(request.headers) @@ -318,7 +323,7 @@ async def create_a_cube( request_headers=request_headers, query_service_client=query_service_client, background_tasks=background_tasks, - validate_access=validate_access, + access_checker=access_checker, save_history=save_history, ): return recreated_node # pragma: no cover @@ -887,7 +892,7 @@ async def update_any_node( current_user: User, save_history: Callable, background_tasks: BackgroundTasks = None, - validate_access: access.ValidateAccessFn = None, + access_checker: AccessChecker = None, refresh_materialization: bool = False, cache: Cache | None = None, ) -> Node: @@ -908,13 +913,17 @@ async def update_any_node( node = cast(Node, node) # Check that the user has access to modify this node - access_control = access.AccessControlStore( - validate_access=validate_access, - user=current_user, - base_verb=access.ResourceAction.WRITE, - ) - access_control.add_request_by_node(node) - access_control.validate_and_raise() + if access_checker: + access_checker.add_request( + access.ResourceRequest( + access_object=access.Resource( + resource_type=access.ResourceType.NODE, + name=node.name, + ), + verb=access.ResourceAction.WRITE, + ), + ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) if data.owners and data.owners != [owner.username for owner in node.owners]: await update_owners(session, node, data.owners, current_user, save_history) @@ -929,7 +938,7 @@ async def update_any_node( query_service_client=query_service_client, current_user=current_user, background_tasks=background_tasks, - validate_access=validate_access, # type: ignore + access_checker=access_checker, # type: ignore save_history=save_history, refresh_materialization=refresh_materialization, ) @@ -942,7 +951,7 @@ async def update_any_node( query_service_client=query_service_client, current_user=current_user, background_tasks=background_tasks, - validate_access=validate_access, # type: ignore + access_checker=access_checker, # type: ignore save_history=save_history, cache=cache, ) @@ -957,7 +966,7 @@ async def update_node_with_query( query_service_client: QueryServiceClient, current_user: User, background_tasks: BackgroundTasks, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, save_history: Callable, cache: Cache, ) -> Node: @@ -1049,7 +1058,7 @@ async def update_node_with_query( lookback_window=old.lookback_window, ) ), - validate_access, + access_checker, current_user=current_user, ), ) @@ -1193,7 +1202,7 @@ async def update_cube_node( query_service_client: QueryServiceClient, current_user: User, background_tasks: BackgroundTasks = None, - validate_access: access.ValidateAccessFn, + access_checker: AccessChecker, save_history: Callable, refresh_materialization: bool = False, ) -> Optional[NodeRevision]: @@ -1296,7 +1305,7 @@ async def update_cube_node( ), job=MaterializationJobTypeEnum.find_match(old.job).value.name, ), - validate_access, + access_checker, current_user=current_user, ), ) @@ -1509,7 +1518,7 @@ async def create_node_from_inactive( query_service_client: QueryServiceClient, save_history: Callable, background_tasks: BackgroundTasks = None, - validate_access: access.ValidateAccessFn = None, + access_checker: AccessChecker = None, cache: Cache | None = None, ) -> Optional[Node]: """ @@ -1559,7 +1568,7 @@ async def create_node_from_inactive( query_service_client=query_service_client, current_user=current_user, background_tasks=background_tasks, - validate_access=validate_access, # type: ignore + access_checker=access_checker, # type: ignore save_history=save_history, cache=cache, ) @@ -1572,7 +1581,7 @@ async def create_node_from_inactive( query_service_client=query_service_client, current_user=current_user, background_tasks=background_tasks, - validate_access=validate_access, # type: ignore + access_checker=access_checker, # type: ignore save_history=save_history, ) try: diff --git a/datajunction-server/datajunction_server/internal/sql.py b/datajunction-server/datajunction_server/internal/sql.py index b3d9e0b50..8ee18c082 100644 --- a/datajunction-server/datajunction_server/internal/sql.py +++ b/datajunction-server/datajunction_server/internal/sql.py @@ -5,6 +5,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload +from datajunction_server.internal.access.authorization import ( + AccessChecker, + AccessDenialMode, +) from datajunction_server.api.helpers import ( assemble_column_metadata, find_existing_cube, @@ -29,7 +33,6 @@ ) from datajunction_server.database import Engine from datajunction_server.database.node import Node, NodeRevision -from datajunction_server.database.user import User from datajunction_server.database.catalog import Catalog from datajunction_server.errors import DJInvalidInputException, DJException from datajunction_server.internal.engines import get_engine @@ -58,7 +61,8 @@ async def build_node_sql( orderby: list[str] | None = None, limit: int | None = None, engine: Engine | None = None, - access_control: access.AccessControlStore | None = None, + *, + access_checker: AccessChecker, ignore_errors: bool = True, use_materialized: bool = True, query_parameters: dict[str, Any] | None = None, @@ -95,7 +99,7 @@ async def build_node_sql( limit=limit, engine_name=engine.name if engine else None, engine_version=engine.version if engine else None, - access_control=access_control, + access_checker=access_checker, use_materialized=use_materialized, query_parameters=query_parameters, ) @@ -113,7 +117,7 @@ async def build_node_sql( limit, engine.name if engine else None, engine.version if engine else None, - access_control=access_control, + access_checker=access_checker, ignore_errors=ignore_errors, use_materialized=use_materialized, query_parameters=query_parameters, @@ -129,7 +133,7 @@ async def build_node_sql( orderby=orderby or [], limit=limit, engine=engine, - access_control=access_control, + access_checker=access_checker, use_materialized=use_materialized, query_parameters=query_parameters, ignore_errors=ignore_errors, @@ -156,7 +160,7 @@ async def build_sql_for_multiple_metrics( limit: int | None = None, engine_name: str | None = None, engine_version: str | None = None, - access_control: access.AccessControlStore | None = None, + access_checker: AccessChecker | None = None, ignore_errors: bool = True, use_materialized: bool = True, query_parameters: dict[str, str] | None = None, @@ -184,8 +188,8 @@ async def build_sql_for_multiple_metrics( ), ], ) - if access_control: - access_control.add_request_by_node(leading_metric_node.current) # type: ignore + if access_checker: + access_checker.add_node(leading_metric_node.current, access.ResourceAction.READ) # type: ignore available_engines = ( leading_metric_node.current.catalog.engines # type: ignore if leading_metric_node.current.catalog # type: ignore @@ -233,10 +237,9 @@ async def build_sql_for_multiple_metrics( validate_orderby(orderby, metrics, dimensions) if cube and cube.availability and use_materialized and materialized_cube_catalog: - if access_control: # pragma: no cover - access_control.add_request_by_node(cube) - access_control.state = access.AccessControlState.INDIRECT - access_control.raise_if_invalid_requests() + if access_checker: # pragma: no cover + access_checker.add_node(cube, access.ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) query_ast = build_materialized_cube_node( metric_columns, dimension_columns, @@ -284,10 +287,15 @@ async def build_sql_for_multiple_metrics( dimensions=dimensions or [], orderby=orderby or [], limit=limit, - access_control=access_control, + access_checker=access_checker, ignore_errors=ignore_errors, query_parameters=query_parameters, ) + + # Check authorization for all discovered nodes + if access_checker: + await access_checker.check(on_denied=AccessDenialMode.RAISE) + columns = [ assemble_column_metadata(col, use_semantic_metadata=True) # type: ignore for col in query_ast.select.projection @@ -322,8 +330,7 @@ async def get_measures_query( orderby: list[str] = None, engine_name: str | None = None, engine_version: str | None = None, - current_user: User | None = None, - validate_access: access.ValidateAccessFn = None, + access_checker: AccessChecker = None, include_all_columns: bool = False, use_materialized: bool = True, preagg_requested: bool = False, @@ -343,15 +350,6 @@ async def get_measures_query( build_criteria = BuildCriteria( dialect=engine.dialect if engine and engine.dialect else Dialect.SPARK, ) - access_control = ( - access.AccessControlStore( - validate_access=validate_access, - user=current_user, - base_verb=access.ResourceAction.READ, - ) - if validate_access - else None - ) if not filters: filters = [] @@ -405,7 +403,7 @@ async def get_measures_query( ) parent_ast = await ( query_builder.ignore_errors() - .with_access_control(access_control) + .with_access_control(access_checker) .with_build_criteria(build_criteria) .add_dimensions(dimensions) .add_filters(filters) diff --git a/datajunction-server/datajunction_server/models/access.py b/datajunction-server/datajunction_server/models/access.py index aa3727110..a74a99d9b 100644 --- a/datajunction-server/datajunction_server/models/access.py +++ b/datajunction-server/datajunction_server/models/access.py @@ -2,22 +2,10 @@ Models for authorization """ -from copy import deepcopy -from enum import Enum -from datajunction_server.typing import StrEnum -from typing import TYPE_CHECKING, Callable, Iterable, Optional, Set, Union - -from pydantic import BaseModel, Field -from sqlalchemy.ext.asyncio import AsyncSession +from dataclasses import dataclass -from datajunction_server.construction.utils import try_get_dj_node +from datajunction_server.typing import StrEnum from datajunction_server.database.node import Node, NodeRevision -from datajunction_server.errors import DJAuthorizationException, DJError, ErrorCode -from datajunction_server.models.user import UserOutput - -if TYPE_CHECKING: - from datajunction_server.database.user import User - from datajunction_server.sql.parsing.ast import Column class ResourceType(StrEnum): @@ -41,39 +29,36 @@ class ResourceAction(StrEnum): MANAGE = "manage" # Grant/revoke permissions (RBAC-specific) -class Resource(BaseModel): +@dataclass(frozen=True) +class Resource: """ Base class for resource objects that are passed to injected validation logic """ - name: str # name of the node + name: str resource_type: ResourceType - owner: str def __hash__(self) -> int: - return hash((self.name, self.resource_type, self.owner)) + return hash((self.name, self.resource_type)) @classmethod - def from_node(cls, node: Union[NodeRevision, Node]) -> "Resource": + def from_node(cls, node: NodeRevision | Node) -> "Resource": """ Create a resource object from a DJ Node """ - return cls(name=node.name, resource_type=ResourceType.NODE, owner="") + return cls(name=node.name, resource_type=ResourceType.NODE) @classmethod def from_namespace(cls, namespace: str) -> "Resource": """ Create a resource object from a namespace """ - return cls( - name=namespace, - resource_type=ResourceType.NAMESPACE, - owner="", - ) + return cls(name=namespace, resource_type=ResourceType.NAMESPACE) -class ResourceRequest(BaseModel): +@dataclass(frozen=True) +class ResourceRequest: """ Resource Requests provide the information that is available to grant access to a resource @@ -81,19 +66,6 @@ class ResourceRequest(BaseModel): verb: ResourceAction access_object: Resource - approved: Optional[bool] = None - - def approve(self): - """ - Approve the request - """ - self.approved = True - - def deny(self): - """ - Deny the request - """ - self.approved = False def __hash__(self) -> int: return hash((self.verb, self.access_object)) @@ -102,207 +74,24 @@ def __eq__(self, other) -> bool: return self.verb == other.verb and self.access_object == other.access_object def __str__(self) -> str: - return ( # pragma: no cover + return ( f"{self.verb.value}:" f"{self.access_object.resource_type.value}/" f"{self.access_object.name}" ) -class AccessControlState(Enum): - """ - State values used by the ACS function to track when - """ - - DIRECT = "direct" - INDIRECT = "indirect" - - -class AccessControl(BaseModel): +@dataclass(frozen=True) +class AccessDecision: """ - An access control provides all the information - necessary to deny or approve a request - """ - - model_config = {"arbitrary_types_allowed": True} - - user: str - state: AccessControlState - direct_requests: Set[ResourceRequest] - indirect_requests: Set[ResourceRequest] - validation_request_count: int - session: Optional[AsyncSession] = None # For RBAC permission checks (deprecated) - user_id: Optional[int] = None # User ID for RBAC lookups (deprecated) - user_object: Optional["User"] = ( - None # Full User object with role_assignments loaded - ) - - @property - def requests(self) -> Set[ResourceRequest]: - """ - Get all direct and indirect requests as a single set - """ - return self.direct_requests | self.indirect_requests - - def approve_all(self): - """ - Approve all requests - """ - for request in self.requests: - request.approve() - - def deny_all(self): - """ - Deny all requests - """ - for request in self.requests: - request.deny() + The result of an access control check for a resource request. - -ValidateAccessFn = Callable[[AccessControl], None] - - -class AccessControlStore(BaseModel): - """ - An access control store tracks all ResourceRequests + Attributes: + request: The resource request that was checked + approved: Whether access was granted + reason: Optional explanation if access was denied """ - model_config = {"arbitrary_types_allowed": True} - - validate_access: Callable[["AccessControl"], bool] - user: Optional[UserOutput] - base_verb: Optional[ResourceAction] = None - state: AccessControlState = AccessControlState.DIRECT - direct_requests: Set[ResourceRequest] = Field(default_factory=set) - indirect_requests: Set[ResourceRequest] = Field(default_factory=set) - validation_request_count: int = 0 - validation_results: Set[ResourceRequest] = Field(default_factory=set) - session: Optional[AsyncSession] = None # For RBAC permission checks (deprecated) - user_object: Optional["User"] = ( - None # Full User object with role_assignments loaded - ) - - def add_request(self, request: ResourceRequest): - """ - Add a resource request to the store - """ - if self.state == AccessControlState.DIRECT: - self.direct_requests.add(request) - else: - self.indirect_requests.add(request) # pragma: no cover - - async def add_request_by_node_name( - self, - session: AsyncSession, - node_name: Union[str, "Column"], - verb: Optional[ResourceAction] = None, - ): - """ - Add a request using a node's name - """ - node = await try_get_dj_node(session, node_name) - if node is not None: - self.add_request_by_node(node, verb) - return node - - def add_request_by_node( - self, - node: Union[NodeRevision, Node], - verb: Optional[ResourceAction] = None, - ): - """ - Add a request using a node - """ - self.add_request( - ResourceRequest( - verb=verb or self.base_verb, - access_object=Resource.from_node(node), - ), - ) - - def add_request_by_nodes( - self, - nodes: Iterable[Union[NodeRevision, Node]], - verb: Optional[ResourceAction] = None, - ): - """ - Add a request using a node - """ - for node in nodes: # pragma: no cover - self.add_request( # pragma: no cover - ResourceRequest( - verb=verb or self.base_verb, - access_object=Resource.from_node(node), - ), - ) - - def raise_if_invalid_requests(self, show_denials: bool = True): - """ - Raises if validate has ever given any invalid requests - """ - denied = ", ".join( - [ - str(request) - for request in self.validation_results - if not request.approved - ], - ) - if denied: - message = ( - f"Authorization of User `{self.user.username if self.user else 'no user'}` " - "for this request failed." - f"\nThe following requests were denied:\n{denied}." - if show_denials - else "" - ) - raise DJAuthorizationException( - errors=[ - DJError( - code=ErrorCode.UNAUTHORIZED_ACCESS, - message=message, - ), - ], - ) - - def validate(self) -> Set[ResourceRequest]: - """ - Checks with ACS and stores any returned invalid requests. - - Now synchronous - authorization works on pre-loaded user data. - """ - self.validation_request_count += 1 - - access_control = AccessControl( - user=self.user.username if self.user is not None else "", - state=self.state, - direct_requests=deepcopy(self.direct_requests), - indirect_requests=deepcopy(self.indirect_requests), - validation_request_count=self.validation_request_count, - session=self.session, # Deprecated - kept for backward compat - user_id=self.user.id if self.user is not None else None, # Deprecated - user_object=self.user_object, # Pass full User object - ) - - # Call validate_access (now sync!) - self.validate_access(access_control) # type: ignore - - self.validation_results = access_control.requests - - if any((result.approved is None for result in self.validation_results)): - raise DJAuthorizationException( - errors=[ - DJError( - code=ErrorCode.INCOMPLETE_AUTHORIZATION, - message="Injected `validate_access` must approve or deny all requests.", - ), - ], - ) - - return self.validation_results - - def validate_and_raise(self): - """ - Validates with ACS and raises if any resources were denied - """ - self.validate() - self.raise_if_invalid_requests() + request: ResourceRequest + approved: bool + reason: str | None = None diff --git a/datajunction-server/tests/api/access_test.py b/datajunction-server/tests/api/access_test.py index 6d7749c05..e02adcb52 100644 --- a/datajunction-server/tests/api/access_test.py +++ b/datajunction-server/tests/api/access_test.py @@ -2,14 +2,28 @@ Tests for the data API. """ +from http import HTTPStatus import pytest from httpx import AsyncClient -from datajunction_server.api.main import app -from datajunction_server.internal.access.authorization import validate_access +from datajunction_server.internal.access.authorization import AuthorizationService from datajunction_server.models import access +class DenyAllAuthorizationService(AuthorizationService): + """ + Custom authorization service that denies all access. + """ + + name = "deny_all" + + def authorize(self, auth_context, requests): + return [ + access.AccessDecision(request=request, approved=False) + for request in requests + ] + + class TestDataAccessControl: """ Test the data access control. @@ -19,43 +33,43 @@ class TestDataAccessControl: async def test_get_metric_data_unauthorized( self, module__client_with_examples: AsyncClient, + mocker, ) -> None: """ Test retrieving data for a metric """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - access_control.deny_all() - - return _validate_access + def get_deny_all_service(): + return DenyAllAuthorizationService() - app.dependency_overrides[validate_access] = validate_access_override + mocker.patch( + "datajunction_server.internal.access.authorization.get_authorization_service", + get_deny_all_service, + ) response = await module__client_with_examples.get("/data/basic.num_comments/") data = response.json() - assert "Authorization of User `dj` for this request failed" in data["message"] - assert "read:node/basic.num_comments" in data["message"] - assert "read:node/basic.source.comments" in data["message"] - assert response.status_code == 403 - app.dependency_overrides.clear() + assert "Access denied to" in data["message"] + assert "basic.num_comments" in data["message"] + assert response.status_code == HTTPStatus.FORBIDDEN @pytest.mark.asyncio async def test_sql_with_filters_orderby_no_access( self, module__client_with_examples: AsyncClient, + mocker, ): """ Test ``GET /sql/{node_name}/`` with various filters and dimensions using a version of the DJ roads database with namespaces. """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - access_control.deny_all() + def get_deny_all_service(): + return DenyAllAuthorizationService() - return _validate_access - - app.dependency_overrides[validate_access] = validate_access_override + mocker.patch( + "datajunction_server.internal.access.authorization.get_authorization_service", + get_deny_all_service, + ) node_name = "foo.bar.num_repair_orders" dimensions = [ @@ -76,13 +90,6 @@ def _validate_access(access_control: access.AccessControl): params={"dimensions": dimensions, "filters": filters, "orderby": orderby}, ) data = response.json() - assert sorted(list(data["message"])) == sorted( - list( - "Authorization of User `dj` for this request failed." - "\nThe following requests were denied:\nread:node/foo.bar.dispatcher, " - "read:node/foo.bar.repair_orders, read:node/foo.bar.municipality_dim, " - "read:node/foo.bar.num_repair_orders, read:node/foo.bar.hard_hat.", - ), - ) - assert data["errors"][0]["code"] == 500 - app.dependency_overrides.clear() + assert "Access denied to" in data["message"] + assert "foo.bar" in data["message"] + assert response.status_code == HTTPStatus.FORBIDDEN diff --git a/datajunction-server/tests/api/dimensions_access_test.py b/datajunction-server/tests/api/dimensions_access_test.py index f235be485..4d18207f3 100644 --- a/datajunction-server/tests/api/dimensions_access_test.py +++ b/datajunction-server/tests/api/dimensions_access_test.py @@ -5,30 +5,43 @@ import pytest from httpx import AsyncClient -from datajunction_server.api.main import app +from datajunction_server.internal.access.authorization import AuthorizationService +from datajunction_server.models import access + + +class RepairOnlyAuthorizationService(AuthorizationService): + """ + Authorization service that only approves nodes with 'repair' in the name. + """ + + name = "repair_only" + + def authorize(self, auth_context, requests): + return [ + access.AccessDecision( + request=request, + approved="repair" in request.access_object.name, + ) + for request in requests + ] @pytest.mark.asyncio async def test_list_nodes_with_dimension_access_limited( module__client_with_roads: AsyncClient, + mocker, ) -> None: """ Test ``GET /dimensions/{name}/nodes/``. """ - from datajunction_server.internal.access.authorization import validate_access - from datajunction_server.models import access - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - for request in access_control.requests: - if "repair" in request.access_object.name: - request.approve() - else: - request.deny() + def get_repair_only_service(): + return RepairOnlyAuthorizationService() - return _validate_access - - app.dependency_overrides[validate_access] = validate_access_override + mocker.patch( + "datajunction_server.internal.access.authorization.get_authorization_service", + get_repair_only_service, + ) response = await module__client_with_roads.get( "/dimensions/default.hard_hat/nodes/", @@ -47,4 +60,3 @@ def _validate_access(access_control: access.AccessControl): "default.avg_repair_order_discounts", } assert {node["name"] for node in data} == roads_repair_nodes - app.dependency_overrides.clear() diff --git a/datajunction-server/tests/api/namespaces_test.py b/datajunction-server/tests/api/namespaces_test.py index 6be0ee497..0450f9a77 100644 --- a/datajunction-server/tests/api/namespaces_test.py +++ b/datajunction-server/tests/api/namespaces_test.py @@ -8,8 +8,9 @@ import pytest from httpx import AsyncClient -from datajunction_server.api.main import app -from datajunction_server.internal.access.authorization import validate_access +from datajunction_server.internal.access.authorization import ( + AuthorizationService, +) from datajunction_server.models import access @@ -831,28 +832,42 @@ async def test_export_namespaces_deployment(client_with_roads: AsyncClient): ] +class DbtOnlyAuthorizationService(AuthorizationService): + """ + Authorization service that only approves namespaces containing 'dbt'. + """ + + name = "dbt_only" + + def authorize(self, auth_context, requests): + return [ + access.AccessDecision( + request=request, + approved=( + request.access_object.resource_type == access.ResourceType.NAMESPACE + and "dbt" in request.access_object.name + ), + ) + for request in requests + ] + + @pytest.mark.asyncio async def test_list_all_namespaces_access_limited( client_with_dbt: AsyncClient, + mocker, ) -> None: """ Test ``GET /namespaces/``. """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - for request in access_control.requests: - if ( - request.access_object.resource_type == access.ResourceType.NAMESPACE - and "dbt" in request.access_object.name - ): - request.approve() - else: - request.deny() + def get_dbt_only_service(): + return DbtOnlyAuthorizationService() - return _validate_access - - app.dependency_overrides[validate_access] = validate_access_override + mocker.patch( + "datajunction_server.internal.access.authorization.get_authorization_service", + get_dbt_only_service, + ) response = await client_with_dbt.get("/namespaces/") @@ -864,63 +879,42 @@ def _validate_access(access_control: access.AccessControl): {"namespace": "dbt.source.stripe", "num_nodes": 1}, {"namespace": "dbt.transform", "num_nodes": 1}, ] - app.dependency_overrides.clear() -@pytest.mark.asyncio -async def test_list_all_namespaces_access_bad_injection( - client_with_service_setup: AsyncClient, -) -> None: +class DenyAllAuthorizationService(AuthorizationService): """ - Test ``GET /namespaces/``. + Authorization service that denies all access requests. """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - for i, request in enumerate(access_control.requests): - if i != 0: - request.approve() - - return _validate_access + name = "deny_all" - app.dependency_overrides[validate_access] = validate_access_override - - response = await client_with_service_setup.get("/namespaces/") - - assert response.status_code == 403 - assert response.json() == { - "message": "Injected `validate_access` must approve or deny all requests.", - "errors": [ - { - "code": 501, - "message": "Injected `validate_access` must approve or deny all requests.", - "debug": None, - "context": "", - }, - ], - "warnings": [], - } - app.dependency_overrides.clear() + def authorize(self, auth_context, requests): + return [ + access.AccessDecision( + request=request, + approved=False, + ) + for request in requests + ] @pytest.mark.asyncio async def test_list_all_namespaces_deny_all( client_with_service_setup: AsyncClient, + mocker, ) -> None: """ Test ``GET /namespaces/``. """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - access_control.deny_all() - - return _validate_access - - app.dependency_overrides[validate_access] = validate_access_override + def get_deny_all_service(): + return DenyAllAuthorizationService() + mocker.patch( + "datajunction_server.internal.access.authorization.get_authorization_service", + get_deny_all_service, + ) response = await client_with_service_setup.get("/namespaces/") assert response.status_code in (200, 201) assert response.json() == [] - app.dependency_overrides.clear() diff --git a/datajunction-server/tests/api/sql_test.py b/datajunction-server/tests/api/sql_test.py index a1063e25b..1ff8923b8 100644 --- a/datajunction-server/tests/api/sql_test.py +++ b/datajunction-server/tests/api/sql_test.py @@ -11,7 +11,7 @@ from datajunction_server.database.node import Node, NodeRevision from datajunction_server.database.queryrequest import QueryBuildType, QueryRequest from datajunction_server.database.user import User -from datajunction_server.internal.access.authorization import validate_access +from datajunction_server.internal.access.authorization import AuthorizationService from datajunction_server.models import access from datajunction_server.models.node_type import NodeType from datajunction_server.sql.parsing.backends.antlr4 import parse @@ -2726,24 +2726,31 @@ async def test_get_sql_for_metrics_failures(module__client_with_examples: AsyncC @pytest.mark.asyncio -async def test_get_sql_for_metrics_no_access(module__client_with_examples: AsyncClient): +async def test_get_sql_for_metrics_no_access( + module__client_with_examples: AsyncClient, + mocker, +): """ - Test getting sql for multiple metrics. + Test getting sql for multiple metrics with denied access. """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - if access_control.state == "direct": - access_control.approve_all() - else: - access_control.deny_all() + # Custom authorization service that denies all requests + class DenyAllAuthorizationService(AuthorizationService): + name = "deny_all" - return _validate_access + def authorize(self, auth_context, requests): + return [ + access.AccessDecision(request=request, approved=False) + for request in requests + ] - module__client_with_examples.app.dependency_overrides[validate_access] = ( - validate_access_override - ) + def get_deny_all_service(): + return DenyAllAuthorizationService() + mocker.patch( + "datajunction_server.internal.access.authorization.get_authorization_service", + get_deny_all_service, + ) response = await module__client_with_examples.get( "/sql/", params={ @@ -2762,20 +2769,12 @@ def _validate_access(access_control: access.AccessControl): }, ) data = response.json() - # assert "Authorization of User `dj` for this request failed.\n" in data["message"] - assert "The following requests were denied:\n" in data["message"] - assert "read:node/default.municipality_dim" in data["message"] - assert "read:node/default.dispatcher" in data["message"] - assert "read:node/default.repair_orders_fact" in data["message"] - assert "read:node/default.hard_hat" in data["message"] - assert data["errors"][0]["code"] == 500 - - module__client_with_examples.app.dependency_overrides[validate_access] = ( - validate_access + assert data["message"] == ( + "Access denied to 10 resource(s): default.discounted_orders_rate, " + "default.discounted_orders_rate, default.num_repair_orders, " + "default.repair_orders_fact, default.hard_hat and 5 more" ) - module__client_with_examples.app.dependency_overrides.clear() - @pytest.mark.asyncio async def test_get_sql_for_metrics2(client_with_examples: AsyncClient): @@ -3369,31 +3368,39 @@ async def test_get_sql_for_metrics_orderby_not_in_dimensions( @pytest.mark.asyncio async def test_get_sql_for_metrics_orderby_not_in_dimensions_no_access( module__client_with_examples: AsyncClient, + mocker, ): """ Test that we extract the columns from filters to validate that they are from shared dimensions """ - def validate_access_override(): - def _validate_access(access_control: access.AccessControl): - for request in access_control.requests: + # Custom authorization service that denies specific nodes + class SelectiveDenialAuthorizationService(AuthorizationService): + name = "selective_denial" + + def authorize(self, auth_context, requests): + denied_nodes = { + "foo.bar.avg_repair_price", + "default.hard_hat.city", + } + return [ + access.AccessDecision(request=request, approved=False) if ( request.access_object.resource_type == access.ResourceType.NODE - and request.access_object.name - in ( - "foo.bar.avg_repair_price", - "default.hard_hat.city", - ) - ): - request.deny() - else: - request.approve() - - return _validate_access - - module__client_with_examples.app.dependency_overrides[validate_access] = ( - validate_access_override + and request.access_object.name in denied_nodes + ) + else access.AccessDecision(request=request, approved=True) + for request in requests + ] + + def get_selective_denial_service(): + return SelectiveDenialAuthorizationService() + + mocker.patch( + "datajunction_server.internal.access.authorization.get_authorization_service", + return_value=SelectiveDenialAuthorizationService(), ) + response = await module__client_with_examples.get( "/sql/", params={ @@ -3410,7 +3417,6 @@ def _validate_access(access_control: access.AccessControl): "Columns ['default.hard_hat.city'] in order by " "clause must also be specified in the metrics or dimensions" ) - module__client_with_examples.app.dependency_overrides.clear() @pytest.mark.asyncio diff --git a/datajunction-server/tests/conftest.py b/datajunction-server/tests/conftest.py index e5ad2d0ca..fdd216967 100644 --- a/datajunction-server/tests/conftest.py +++ b/datajunction-server/tests/conftest.py @@ -55,8 +55,10 @@ from datajunction_server.database.engine import Engine from datajunction_server.database.user import User from datajunction_server.errors import DJQueryServiceClientEntityNotFound -from datajunction_server.internal.access.authorization import validate_access -from datajunction_server.models.access import AccessControl, ValidateAccessFn +from datajunction_server.internal.access.authorization import ( + get_authorization_service, + PassthroughAuthorizationService, +) from datajunction_server.models.materialization import MaterializationInfo from datajunction_server.models.query import QueryCreate, QueryWithResults from datajunction_server.models.user import OAuthProvider @@ -523,16 +525,14 @@ def get_session_override() -> AsyncSession: def get_settings_override() -> Settings: return settings_no_qs - def default_validate_access() -> ValidateAccessFn: - def _(access_control: AccessControl): - access_control.approve_all() - - return _ + def get_passthrough_auth_service(): + """Override to approve all requests in tests.""" + return PassthroughAuthorizationService() if use_patch: app.dependency_overrides[get_session] = get_session_override app.dependency_overrides[get_settings] = get_settings_override - app.dependency_overrides[validate_access] = default_validate_access + app.dependency_overrides[get_authorization_service] = get_passthrough_auth_service async with AsyncClient( transport=httpx.ASGITransport(app=app), @@ -805,18 +805,16 @@ def get_query_service_client_override( def get_settings_override() -> Settings: return settings - def default_validate_access() -> ValidateAccessFn: - def _(access_control: AccessControl): - access_control.approve_all() - - return _ + def get_passthrough_auth_service(): + """Override to approve all requests in tests.""" + return PassthroughAuthorizationService() def get_session_override() -> AsyncSession: return session app.dependency_overrides[get_session] = get_session_override app.dependency_overrides[get_settings] = get_settings_override - app.dependency_overrides[validate_access] = default_validate_access + app.dependency_overrides[get_authorization_service] = get_passthrough_auth_service app.dependency_overrides[get_query_service_client] = ( get_query_service_client_override ) @@ -984,11 +982,9 @@ def get_session_override() -> AsyncSession: def get_settings_override() -> Settings: return module__settings - def default_validate_access() -> ValidateAccessFn: - def _(access_control: AccessControl): - access_control.approve_all() - - return _ + def get_passthrough_auth_service(): + """Override to approve all requests in tests.""" + return PassthroughAuthorizationService() module_mocker.patch( "datajunction_server.api.materializations.get_query_service_client", @@ -997,7 +993,7 @@ def _(access_control: AccessControl): app.dependency_overrides[get_session] = get_session_override app.dependency_overrides[get_settings] = get_settings_override - app.dependency_overrides[validate_access] = default_validate_access + app.dependency_overrides[get_authorization_service] = get_passthrough_auth_service app.dependency_overrides[get_query_service_client] = ( get_query_service_client_override ) diff --git a/datajunction-server/tests/internal/authorization_test.py b/datajunction-server/tests/internal/authorization_test.py index e4b4d3fcf..85be5092c 100644 --- a/datajunction-server/tests/internal/authorization_test.py +++ b/datajunction-server/tests/internal/authorization_test.py @@ -9,11 +9,11 @@ from datajunction_server.database.rbac import Role, RoleAssignment, RoleScope from datajunction_server.database.user import PrincipalKind, User from datajunction_server.internal.access.authorization import ( + AccessChecker, AccessDenialMode, AuthContext, PassthroughAuthorizationService, RBACAuthorizationService, - authorize, get_authorization_service, ) from datajunction_server.errors import DJAuthorizationException @@ -426,7 +426,6 @@ async def test_passthrough_service_approves_all( access_object=Resource( name="finance.revenue", resource_type=ResourceType.NAMESPACE, - owner="", ), ), ResourceRequest( @@ -434,7 +433,6 @@ async def test_passthrough_service_approves_all( access_object=Resource( name="secret.data", resource_type=ResourceType.NODE, - owner="", ), ), ] @@ -484,7 +482,6 @@ async def test_rbac_service_with_permissions( access_object=Resource( name="finance.revenue", resource_type=ResourceType.NAMESPACE, - owner="", ), ), ResourceRequest( @@ -492,17 +489,15 @@ async def test_rbac_service_with_permissions( access_object=Resource( name="finance.revenue", resource_type=ResourceType.NAMESPACE, - owner="", ), ), ] - result = await authorize( - session, - user, - requests, - on_denied=AccessDenialMode.RETURN, + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), ) + access_checker.add_requests(requests) + result = await access_checker.check(on_denied=AccessDenialMode.RETURN) assert len(result) == 2 assert result[0].approved is True # READ granted assert result[1].approved is False # WRITE not granted @@ -601,16 +596,16 @@ async def test_user_inherits_group_permissions( mock_settings.default_access_policy = "restrictive" # Check permission - should be granted via group - results = await authorize( - session=session, - user=user, - resource_requests=[ + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_requests( + [ ResourceRequest( verb=ResourceAction.READ, access_object=Resource( name="finance.revenue.something", resource_type=ResourceType.NODE, - owner="", ), ), ResourceRequest( @@ -618,11 +613,11 @@ async def test_user_inherits_group_permissions( access_object=Resource( resource_type=ResourceType.NAMESPACE, name="finance.revenue", - owner="", ), ), ], ) + results = await access_checker.check(on_denied=AccessDenialMode.RETURN) assert results[0].approved is True assert results[1].approved is True @@ -674,21 +669,21 @@ async def test_user_no_permission_without_group( mock_settings.default_access_policy = "restrictive" # Check permission - should NOT be granted (user not in group) - results = await authorize( - session=session, - user=user, - resource_requests=[ + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_requests( + [ ResourceRequest( verb=ResourceAction.READ, access_object=Resource( name="marketing.revenue", resource_type=ResourceType.NAMESPACE, - owner="", ), ), ], - on_denied=AccessDenialMode.RETURN, ) + results = await access_checker.check(on_denied=AccessDenialMode.RETURN) assert results[0].approved is False @@ -1459,7 +1454,6 @@ async def test_check_access_filter_mode_returns_only_approved( access_object=Resource( name="finance.revenue", resource_type=ResourceType.NODE, - owner="", ), ), ResourceRequest( @@ -1467,7 +1461,6 @@ async def test_check_access_filter_mode_returns_only_approved( access_object=Resource( name="finance.cost", resource_type=ResourceType.NODE, - owner="", ), ), ResourceRequest( @@ -1475,7 +1468,6 @@ async def test_check_access_filter_mode_returns_only_approved( access_object=Resource( name="marketing.revenue", resource_type=ResourceType.NODE, - owner="", ), ), ] @@ -1487,18 +1479,15 @@ async def test_check_access_filter_mode_returns_only_approved( mock_settings.default_access_policy = "restrictive" # Check access (default FILTER mode) - approved = await authorize( - session, - user, - requests, - on_denied=AccessDenialMode.FILTER, + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), ) + access_checker.add_requests(requests) + approved = await access_checker.approved_resource_names() # Should only return the 2 approved (finance.* nodes) assert len(approved) == 2 - assert all(req.approved for req in approved) - approved_names = {req.access_object.name for req in approved} - assert approved_names == {"finance.revenue", "finance.cost"} + assert approved == {"finance.revenue", "finance.cost"} async def test_check_access_raise_mode_throws_on_denial( self, @@ -1517,7 +1506,6 @@ async def test_check_access_raise_mode_throws_on_denial( access_object=Resource( name="finance.revenue", resource_type=ResourceType.NODE, - owner="", ), ) @@ -1528,12 +1516,11 @@ async def test_check_access_raise_mode_throws_on_denial( mock_settings.default_access_policy = "restrictive" with pytest.raises(DJAuthorizationException) as exc_info: - await authorize( - session, - user, - [request], - on_denied=AccessDenialMode.RAISE, + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), ) + access_checker.add_request(request) + await access_checker.check(on_denied=AccessDenialMode.RAISE) # Check exception message assert "Access denied" in str(exc_info.value) @@ -1574,18 +1561,15 @@ async def test_check_access_raise_mode_succeeds_when_approved( access_object=Resource( name="finance.revenue", resource_type=ResourceType.NODE, - owner="", ), ) # Should NOT raise - result = await authorize( - session, - user, - [request], - on_denied=AccessDenialMode.RAISE, + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), ) - + access_checker.add_request(request) + result = await access_checker.check(on_denied=AccessDenialMode.RAISE) assert len(result) == 1 assert result[0].approved is True @@ -1626,7 +1610,6 @@ async def test_check_access_return_mode( access_object=Resource( name="finance.revenue", resource_type=ResourceType.NODE, - owner="", ), ), ResourceRequest( @@ -1634,7 +1617,6 @@ async def test_check_access_return_mode( access_object=Resource( name="finance.cost", resource_type=ResourceType.NODE, - owner="", ), ), ResourceRequest( @@ -1642,7 +1624,6 @@ async def test_check_access_return_mode( access_object=Resource( name="marketing.revenue", resource_type=ResourceType.NODE, - owner="", ), ), ] @@ -1654,12 +1635,12 @@ async def test_check_access_return_mode( mock_settings.default_access_policy = "restrictive" # Check access with RETURN_ALL - all_requests = await authorize( - session, - user, - requests, - on_denied=AccessDenialMode.RETURN, + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), ) + access_checker.add_requests(requests) + + all_requests = await access_checker.check(on_denied=AccessDenialMode.RETURN) # Should return all 3 requests assert len(all_requests) == 3 @@ -1670,7 +1651,7 @@ async def test_check_access_return_mode( assert len(approved) == 2 assert len(denied) == 1 - assert denied[0].access_object.name == "marketing.revenue" + assert denied[0].request.access_object.name == "marketing.revenue" @pytest.mark.asyncio @@ -1889,12 +1870,15 @@ async def test_check_access_with_group_based_permissions( access_object=Resource( name="data.user_events", resource_type=ResourceType.NODE, - owner="", ), ) # Should be approved via group - approved = await authorize(session, user, [request]) + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), + ) + access_checker.add_request(request) + approved = await access_checker.check(on_denied=AccessDenialMode.RETURN) assert len(approved) == 1 assert approved[0].approved is True @@ -1936,7 +1920,6 @@ async def test_check_access_with_mixed_approval( access_object=Resource( name="finance.revenue", resource_type=ResourceType.NODE, - owner="", ), ), ResourceRequest( @@ -1944,7 +1927,6 @@ async def test_check_access_with_mixed_approval( access_object=Resource( name="marketing.revenue", resource_type=ResourceType.NODE, - owner="", ), ), ] @@ -1955,31 +1937,20 @@ async def test_check_access_with_mixed_approval( mock_settings.default_access_policy = "restrictive" # FILTER mode - returns only approved - filtered = await authorize( - session, - user, - requests, - on_denied=AccessDenialMode.FILTER, + access_checker = AccessChecker( + auth_context=await AuthContext.from_user(user=user, session=session), ) + access_checker.add_requests(requests) + filtered = await access_checker.check(on_denied=AccessDenialMode.FILTER) assert len(filtered) == 1 - assert filtered[0].access_object.name == "finance.revenue" + assert filtered[0].request.access_object.name == "finance.revenue" # RETURN_ALL mode - returns both - all_results = await authorize( - session, - user, - requests, - on_denied=AccessDenialMode.RETURN, - ) + all_results = await access_checker.check(on_denied=AccessDenialMode.RETURN) assert len(all_results) == 2 assert all_results[0].approved is True assert all_results[1].approved is False # RAISE mode - should raise with pytest.raises(DJAuthorizationException): - await authorize( - session, - user, - requests, - on_denied=AccessDenialMode.RAISE, - ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) diff --git a/datajunction-server/tests/internal/deployment/orchestration_test.py b/datajunction-server/tests/internal/deployment/orchestration_test.py index 1cf30eac7..dd8e8fe31 100644 --- a/datajunction-server/tests/internal/deployment/orchestration_test.py +++ b/datajunction-server/tests/internal/deployment/orchestration_test.py @@ -42,7 +42,6 @@ def mock_deployment_context(current_user: User): context.current_user = current_user context.request = Mock() context.query_service_client = Mock() - context.validate_access = AsyncMock(return_value=True) context.background_tasks = Mock() context.save_history = AsyncMock() context.cache = Mock() From d0252cdebced7ea70eb7456abc48bf0256a52990 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Mon, 1 Dec 2025 23:42:10 -0800 Subject: [PATCH 3/4] Refactor authorization logic into multiple modules + fix authorization tests --- .../datajunction_server/api/dimensions.py | 7 +- .../datajunction_server/api/graphql/main.py | 2 + .../datajunction_server/api/namespaces.py | 45 +- .../datajunction_server/api/nodes.py | 186 ++++- .../datajunction_server/config.py | 4 +- .../construction/build_v2.py | 6 +- .../internal/access/authorization.py | 636 ------------------ .../internal/access/authorization/__init__.py | 31 + .../internal/access/authorization/context.py | 130 ++++ .../internal/access/authorization/service.py | 329 +++++++++ .../access/authorization/validator.py | 164 +++++ .../internal/caching/query_cache_manager.py | 1 - .../datajunction_server/internal/nodes.py | 15 - .../datajunction_server/internal/sql.py | 2 +- datajunction-server/tests/api/access_test.py | 156 ++++- .../tests/api/dimensions_access_test.py | 2 +- .../api/graphql/common_dimensions_test.py | 2 +- .../tests/api/namespaces_test.py | 4 +- datajunction-server/tests/api/sql_test.py | 4 +- .../tests/construction/build_test.py | 57 +- .../tests/internal/authorization_test.py | 76 ++- .../caching/query_cache_manager_test.py | 116 ++-- .../tests/models/access_test.py | 128 ++++ 23 files changed, 1366 insertions(+), 737 deletions(-) delete mode 100644 datajunction-server/datajunction_server/internal/access/authorization.py create mode 100644 datajunction-server/datajunction_server/internal/access/authorization/__init__.py create mode 100644 datajunction-server/datajunction_server/internal/access/authorization/context.py create mode 100644 datajunction-server/datajunction_server/internal/access/authorization/service.py create mode 100644 datajunction-server/datajunction_server/internal/access/authorization/validator.py create mode 100644 datajunction-server/tests/models/access_test.py diff --git a/datajunction-server/datajunction_server/api/dimensions.py b/datajunction-server/datajunction_server/api/dimensions.py index 0bb8a341e..2e8b72152 100644 --- a/datajunction-server/datajunction_server/api/dimensions.py +++ b/datajunction-server/datajunction_server/api/dimensions.py @@ -3,7 +3,7 @@ """ import logging -from typing import List, Optional +from typing import List, Optional, cast from fastapi import Depends, Query from sqlalchemy.ext.asyncio import AsyncSession @@ -71,7 +71,10 @@ async def find_nodes_with_dimension( """ List all nodes that have the specified dimension """ - dimension_node = await Node.get_by_name(session, name) + dimension_node = cast( + Node, + await Node.get_by_name(session, name, raise_if_not_exists=True), + ) access_checker.add_node(dimension_node, access.ResourceAction.READ) nodes = await get_nodes_with_common_dimensions( diff --git a/datajunction-server/datajunction_server/api/graphql/main.py b/datajunction-server/datajunction_server/api/graphql/main.py index 088ba548b..fa3b0f345 100644 --- a/datajunction-server/datajunction_server/api/graphql/main.py +++ b/datajunction-server/datajunction_server/api/graphql/main.py @@ -9,6 +9,7 @@ from strawberry.types import Info from datajunction_server.internal.caching.cachelib_cache import get_cache +from datajunction_server.internal.access.authentication.http import DJHTTPBearer from datajunction_server.api.graphql.queries.catalogs import list_catalogs from datajunction_server.api.graphql.queries.dag import ( common_dimensions, @@ -82,6 +83,7 @@ async def get_context( background_tasks: BackgroundTasks, db_session=Depends(get_session), cache=Depends(get_cache), + _auth=Depends(DJHTTPBearer(auto_error=False)), ): """ Provides the context for graphql requests diff --git a/datajunction-server/datajunction_server/api/namespaces.py b/datajunction-server/datajunction_server/api/namespaces.py index 6aeba17cf..d2011af76 100644 --- a/datajunction-server/datajunction_server/api/namespaces.py +++ b/datajunction-server/datajunction_server/api/namespaces.py @@ -15,12 +15,14 @@ from datajunction_server.database.node import Node from datajunction_server.database.user import User from datajunction_server.errors import DJAlreadyExistsException +from datajunction_server.models.access import ResourceAction from datajunction_server.models.deployment import CubeSpec, DeploymentSpec from datajunction_server.models.dimensionlink import LinkType from datajunction_server.internal.access.authentication.http import SecureAPIRouter from datajunction_server.internal.access.authorization import ( AccessChecker, get_access_checker, + AccessDenialMode, ) from datajunction_server.internal.namespaces import ( create_namespace, @@ -143,17 +145,38 @@ async def list_nodes_in_namespace( description="Whether to include a list of users who edited each node", ), session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeMinimumDetail]: """ List node names in namespace, filterable to a given type if desired. """ - return await NodeNamespace.list_nodes( + # Check that the user has namespace-level READ access + access_checker.add_namespace(namespace, access.ResourceAction.READ) + namespace_decisions = await access_checker.check( + on_denied=AccessDenialMode.FILTER, + ) + if not namespace_decisions: + # User has no access to this namespace at all + return [] + + # Get all nodes in namespace + nodes = await NodeNamespace.list_nodes( session, namespace, type_, with_edited_by=with_edited_by, ) + # Filter to nodes the user has READ access to + access_checker.add_nodes(nodes=nodes, action=access.ResourceAction.READ) + node_decisions = await access_checker.check(on_denied=AccessDenialMode.RETURN) + approved_names = { + decision.request.access_object.name + for decision in node_decisions + if decision.approved + } + return [node for node in nodes if node.name in approved_names] + @router.delete("/namespaces/{namespace}/", status_code=HTTPStatus.OK) async def deactivate_a_namespace( @@ -169,10 +192,14 @@ async def deactivate_a_namespace( query_service_client: QueryServiceClient = Depends(get_query_service_client), background_tasks: BackgroundTasks, request: Request, + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Deactivates a node namespace """ + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node_namespace = await NodeNamespace.get( session, namespace, @@ -253,10 +280,14 @@ async def restore_a_namespace( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Restores a node namespace """ + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node_namespace = await get_node_namespace( session=session, namespace=namespace, @@ -327,6 +358,7 @@ async def hard_delete_node_namespace( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Hard delete a namespace, which will completely remove the namespace. Additionally, @@ -334,6 +366,9 @@ async def hard_delete_node_namespace( is set to true. If cascade is set to false, we'll raise an error. This should be used with caution, as the impact may be large. """ + access_checker.add_namespace(namespace, ResourceAction.DELETE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + impacts = await hard_delete_namespace( session=session, namespace=namespace, @@ -358,11 +393,15 @@ async def export_a_namespace( namespace: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[Dict]: """ Generates a zip of YAML files for the contents of the given namespace as well as a project definition file. """ + access_checker.add_namespace(namespace, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + return await get_project_config( session=session, nodes=await get_nodes_in_namespace_detailed(session, namespace), @@ -387,10 +426,14 @@ async def export_namespace_spec( namespace: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> DeploymentSpec: """ Generates a deployment spec for a namespace """ + access_checker.add_namespace(namespace, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + nodes = await NodeNamespace.list_all_nodes( session, namespace, diff --git a/datajunction-server/datajunction_server/api/nodes.py b/datajunction-server/datajunction_server/api/nodes.py index 0d0effbf2..343867416 100644 --- a/datajunction-server/datajunction_server/api/nodes.py +++ b/datajunction-server/datajunction_server/api/nodes.py @@ -45,7 +45,9 @@ from datajunction_server.internal.access.authorization import ( AccessChecker, get_access_checker, + AccessDenialMode, ) +from datajunction_server.models.access import ResourceAction from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.internal.nodes import ( activate_node, @@ -160,10 +162,14 @@ async def revalidate( save_history: Callable = Depends(get_save_history), *, background_tasks: BackgroundTasks, + access_checker: AccessChecker = Depends(get_access_checker), ) -> NodeStatusDetails: """ Revalidate a single existing node and update its status appropriately """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node_validator = await revalidate_node( name=name, session=session, @@ -210,10 +216,14 @@ async def set_column_attributes( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[ColumnOutput]: """ Set column attributes for the node. """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, node_name, @@ -284,16 +294,18 @@ async def list_all_nodes_with_details( "%s limit reached when returning all nodes, all nodes may not be captured in results", settings.node_list_max, ) - for row in results: - access_checker.add_request( + access_checker.add_requests( + [ access.ResourceRequest( verb=access.ResourceAction.READ, access_object=access.Resource( name=row.name, resource_type=access.ResourceType.NODE, ), - ), - ) + ) + for row in results + ], + ) approvals = await access_checker.approved_resource_names() return [row for row in results if row.name in approvals] @@ -303,10 +315,14 @@ async def get_node( name: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> NodeOutput: """ Show the active version of the specified node. """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, name, @@ -326,10 +342,14 @@ async def delete_node( query_service_client: QueryServiceClient = Depends(get_query_service_client), background_tasks: BackgroundTasks, request: Request, + access_checker: AccessChecker = Depends(get_access_checker), ): """ Delete (aka deactivate) the specified node. """ + access_checker.add_request_by_node_name(name, ResourceAction.DELETE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + await deactivate_node( session=session, name=name, @@ -351,11 +371,15 @@ async def hard_delete( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Hard delete a node, destroying all links and invalidating all downstream nodes. This should be used with caution, deactivating a node is preferred. """ + access_checker.add_request_by_node_name(name, ResourceAction.DELETE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + impact = await hard_delete_node( name=name, session=session, @@ -378,10 +402,14 @@ async def restore_node( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ): """ Restore (aka re-activate) the specified node. """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + await activate_node( session=session, name=name, @@ -399,10 +427,14 @@ async def list_node_revisions( name: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[NodeRevisionOutput]: """ List all revisions for the node. """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, name, @@ -430,6 +462,10 @@ async def create_source( Create a source node. If columns are not provided, the source node's schema will be inferred using the configured query service. """ + namespace = data.namespace or data.name.rsplit(".", 1)[0] + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + return await create_a_source_node( data=data, request=request, @@ -476,6 +512,11 @@ async def create_node( Create a node. """ node_type = NodeType(os.path.basename(os.path.normpath(request.url.path))) + + namespace = data.namespace or data.name.rsplit(".", 1)[0] + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + return await create_a_node( data=data, request=request, @@ -510,6 +551,22 @@ async def create_cube( """ Create a cube node. """ + # Check WRITE access on the namespace for creating the cube + namespace = data.namespace or data.name.rsplit(".", 1)[0] + access_checker.add_namespace(namespace, ResourceAction.WRITE) + + # Check READ access on all metrics and dimensions being included in the cube + if data.metrics: + for metric_name in data.metrics: + access_checker.add_request_by_node_name(metric_name, ResourceAction.READ) + if data.dimensions: + for dim_attr in data.dimensions: + # Dimension attributes are in format "node_name.column_name" + dim_node_name = dim_attr.rsplit(".", 1)[0] + access_checker.add_request_by_node_name(dim_node_name, ResourceAction.READ) + + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await create_a_cube( data=data, request=request, @@ -545,6 +602,7 @@ async def register_table( current_user: User = Depends(get_current_user), background_tasks: BackgroundTasks, save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> NodeOutput: """ Register a table. This creates a source node in the SOURCE_NODE_NAMESPACE and @@ -572,6 +630,8 @@ async def register_table( current_user=current_user, save_history=save_history, ) + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) # Use reflection to get column names and types _catalog = await get_catalog_by_name(session=session, name=catalog) @@ -598,6 +658,7 @@ async def register_table( background_tasks=background_tasks, save_history=save_history, request=request, + access_checker=access_checker, ) @@ -619,6 +680,7 @@ async def register_view( current_user: User = Depends(get_current_user), background_tasks: BackgroundTasks, save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> NodeOutput: """ Register a view by creating the view in the database and adding a source node for it. @@ -636,6 +698,9 @@ async def register_view( view_name = f"{schema_}.{view}" await raise_if_node_exists(session, node_name) + access_checker.add_namespace(namespace, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + # Re-create the view in the database _catalog = await get_catalog_by_name(session=session, name=catalog) or_replace = "OR REPLACE" if replace else "" @@ -687,6 +752,7 @@ async def register_view( background_tasks=background_tasks, save_history=save_history, request=request, + access_checker=access_checker, ) @@ -699,6 +765,7 @@ async def link_dimension( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Add a simple dimension link from a node column to a dimension node. @@ -706,6 +773,10 @@ async def link_dimension( 2. If no `dimension_column` is provided, the primary key column of the dimension node will be used as the join column for the link. """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + access_checker.add_request_by_node_name(dimension, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + activity_type = await upsert_simple_dimension_link( session, name, @@ -741,10 +812,15 @@ async def add_reference_dimension_link( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Add reference dimension link to a node column """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + access_checker.add_request_by_node_name(dimension_node, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + await upsert_reference_dimension_link( session=session, node_name=node_name, @@ -773,10 +849,14 @@ async def remove_reference_dimension_link( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Remove reference dimension link from a node column """ + access_checker.add_request_by_node_name(node_name, ResourceAction.DELETE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name(session, node_name, raise_if_not_exists=True) target_column = await get_column(session, node.current, node_column) # type: ignore if target_column.dimension_id or target_column.dimension_column: @@ -823,11 +903,19 @@ async def add_complex_dimension_link( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Links a source, dimension, or transform node to a dimension with a custom join query. If a link already exists, updates the link definition. """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + access_checker.add_request_by_node_name( + link_input.dimension_node, + ResourceAction.READ, + ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + activity_type = await upsert_complex_dimension_link( session, node_name, @@ -858,10 +946,16 @@ async def remove_complex_dimension_link( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Removes a complex dimension link based on the dimension node and its role (if any). """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + access_checker.add_request_by_node_name( + link_identifier.dimension_node, + ResourceAction.READ, + ) return await remove_dimension_link( session, node_name, @@ -880,10 +974,15 @@ async def delete_dimension_link( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Remove the link between a node column and a dimension node """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + access_checker.add_request_by_node_name(dimension, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + return await remove_dimension_link( session, name, @@ -906,10 +1005,14 @@ async def tags_node( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Add a tag to a node """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name(session=session, name=name) existing_tags = {tag.name for tag in node.tags} # type: ignore if not tag_names: @@ -960,10 +1063,14 @@ async def refresh_source_node( query_service_client: QueryServiceClient = Depends(get_query_service_client), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> NodeOutput: """ Refresh a source node with the latest columns from the query service. """ + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + return await refresh_source( # type: ignore name=name, session=session, @@ -992,6 +1099,22 @@ async def update_node( """ Update a node. """ + # Check WRITE access on the node being updated + access_checker.add_request_by_node_name(name, ResourceAction.WRITE) + + # For cube updates: check READ access on any metrics/dimensions being added + # (user must have access to read nodes they're including in the cube) + if data.metrics: + for metric_name in data.metrics: + access_checker.add_request_by_node_name(metric_name, ResourceAction.READ) + if data.dimensions: + for dim_attr in data.dimensions: + # Dimension attributes are in format "node_name.column_name" + dim_node_name = dim_attr.rsplit(".", 1)[0] + access_checker.add_request_by_node_name(dim_node_name, ResourceAction.READ) + + await access_checker.check(on_denied=AccessDenialMode.RAISE) + request_headers = dict(request.headers) await update_any_node( name, @@ -1021,10 +1144,15 @@ async def calculate_node_similarity( node2_name: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> JSONResponse: """ Compare two nodes by how similar their queries are """ + access_checker.add_request_by_node_name(node1_name, ResourceAction.READ) + access_checker.add_request_by_node_name(node2_name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node1 = await Node.get_by_name( session, node1_name, @@ -1059,12 +1187,16 @@ async def list_downstream_nodes( node_type: NodeType = None, depth: int = -1, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[DAGNodeOutput]: """ List all nodes that are downstream from the given node, filterable by type and max depth. Setting a max depth of -1 will include all downstream nodes. """ - return await get_downstream_nodes( + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + + downstreams = await get_downstream_nodes( session=session, node_name=name, node_type=node_type, @@ -1072,6 +1204,11 @@ async def list_downstream_nodes( depth=depth, ) + for node in downstreams: + access_checker.add_request_by_node_name(node.name, ResourceAction.READ) + accessible = await access_checker.approved_resource_names() + return [node for node in downstreams if node.name in accessible] + @router.get( "/nodes/{name}/upstream/", @@ -1085,10 +1222,14 @@ async def list_upstream_nodes( cache: Cache = Depends(get_cache), background_tasks: BackgroundTasks, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[DAGNodeOutput]: """ List all nodes that are upstream from the given node, filterable by type. """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = cast(Node, await Node.get_by_name(session, name, raise_if_not_exists=True)) upstream_cache_key = node.upstream_cache_key() results = cache.get(upstream_cache_key) @@ -1100,7 +1241,11 @@ async def list_upstream_nodes( results, timeout=settings.query_cache_timeout, ) - return results + + for node in results: + access_checker.add_request_by_node_name(node.name, ResourceAction.READ) + accessible = await access_checker.approved_resource_names() + return [node for node in results if node.name in accessible] @router.get( @@ -1111,11 +1256,15 @@ async def list_node_dag( name: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[DAGNodeOutput]: """ List all nodes that are part of the DAG of the given node. This means getting all upstreams, downstreams, and linked dimension nodes. """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, name, @@ -1151,10 +1300,14 @@ async def list_all_dimension_attributes( *, depth: int = 30, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> list[DimensionAttributeOutput]: """ List all available dimension attributes for the given node. """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + dimensions = await get_dimension_attributes(session, name) filter_only_dimensions = await get_filter_only_dimensions(session, name) return dimensions + filter_only_dimensions @@ -1169,10 +1322,13 @@ async def column_lineage( name: str, *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> List[LineageColumn]: """ List column-level lineage of a node in a graph """ + access_checker.add_request_by_node_name(name, ResourceAction.READ) + await access_checker.check(on_denied=AccessDenialMode.RAISE) node = await Node.get_by_name( session, @@ -1207,10 +1363,14 @@ async def set_column_display_name( save_history: Callable = Depends(get_save_history), *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> ColumnOutput: """ Set column name for the node """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, node_name, @@ -1249,10 +1409,14 @@ async def set_column_description( save_history: Callable = Depends(get_save_history), *, session: AsyncSession = Depends(get_session), + access_checker: AccessChecker = Depends(get_access_checker), ) -> ColumnOutput: """ Set column description for the node """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, node_name, @@ -1292,10 +1456,14 @@ async def set_column_partition( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> ColumnOutput: """ Add or update partition columns for the specified node. """ + access_checker.add_request_by_node_name(node_name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + node = await Node.get_by_name( session, node_name, @@ -1362,6 +1530,7 @@ async def copy_node( session: AsyncSession = Depends(get_session), current_user: User = Depends(get_current_user), save_history: Callable = Depends(get_save_history), + access_checker: AccessChecker = Depends(get_access_checker), ) -> DAGNodeOutput: """ Copy this node to a new name. @@ -1370,6 +1539,11 @@ async def copy_node( new_node_namespace = ".".join(new_name.split(".")[:-1]) await get_node_namespace(session, new_node_namespace, raise_if_not_exists=True) + # Check that the user has access to read the existing node and write to the new namespace + access_checker.add_request_by_node_name(node_name, ResourceAction.READ) + access_checker.add_namespace(new_name, ResourceAction.WRITE) + await access_checker.check(on_denied=AccessDenialMode.RAISE) + # Check if there is already a node with the new name existing_new_node = await get_node_by_name( session, diff --git a/datajunction-server/datajunction_server/config.py b/datajunction-server/datajunction_server/config.py index 1f6d307a9..9e0e2bd73 100644 --- a/datajunction-server/datajunction_server/config.py +++ b/datajunction-server/datajunction_server/config.py @@ -165,8 +165,8 @@ class Settings(BaseSettings): # pragma: no cover authorization_provider: str = "rbac" # Default access policy when no explicit RBAC rule exists: - # - "permissive": Allow by default (OSS-friendly, lock down as needed) - # - "restrictive": Deny by default (Enterprise, explicitly grant access) + # - "permissive": Allow by default + # - "restrictive": Deny by default default_access_policy: str = "permissive" # or "restrictive" # Interval in seconds with which to expire caching of any indexes diff --git a/datajunction-server/datajunction_server/construction/build_v2.py b/datajunction-server/datajunction_server/construction/build_v2.py index 0fc8503b8..3ac4a1d39 100644 --- a/datajunction-server/datajunction_server/construction/build_v2.py +++ b/datajunction-server/datajunction_server/construction/build_v2.py @@ -1403,7 +1403,7 @@ async def build(self) -> ast.Query: async def validate_access(self): """Validates access""" - if self._access_checker: + if self._access_checker: # pragma: no cover await self._access_checker.check(on_denied=AccessDenialMode.RAISE) async def build_measures_queries(self): @@ -1481,14 +1481,14 @@ async def build_metric_agg( """ Build the metric's aggregate expression. """ - if self._access_checker: + if self._access_checker: # pragma: no cover self._access_checker.add_node(metric_node, access.ResourceAction.READ) # type: ignore metric_query_builder = await QueryBuilder.create(self.session, metric_node) if self._ignore_errors: metric_query_builder = ( # pragma: no cover metric_query_builder.ignore_errors() ) - if self._access_checker: + if self._access_checker: # pragma: no cover metric_query_builder = metric_query_builder.with_access_control( self._access_checker, ) diff --git a/datajunction-server/datajunction_server/internal/access/authorization.py b/datajunction-server/datajunction_server/internal/access/authorization.py deleted file mode 100644 index e51d188d4..000000000 --- a/datajunction-server/datajunction_server/internal/access/authorization.py +++ /dev/null @@ -1,636 +0,0 @@ -""" -Authorization related functionality using pluggable services. - -This module defines an abstract base class `AuthorizationService` for implementing different -authorization strategies. It includes built-in implementations such as -`RBACAuthorizationService` for role-based access control and `PassthroughAuthorizationService` -for permissive access. - -Example custom implementation: -```python -class CustomAuthService(AuthorizationService): - name = "custom" - - def authorize(self, auth_context, requests): - # Sync in-memory authorization logic - return requests -``` -""" - -from fastapi import Depends -from abc import ABC, abstractmethod -from dataclasses import dataclass -from datetime import datetime, timezone -from enum import Enum -from functools import lru_cache -from typing import List, Optional - -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload - -from datajunction_server.internal.access.group_membership import ( - GroupMembershipService, - get_group_membership_service, -) -from datajunction_server.database.node import Node -from datajunction_server.database.rbac import RoleAssignment -from datajunction_server.database.user import User -from datajunction_server.models.access import ( - AccessDecision, - Resource, - ResourceAction, - ResourceRequest, - ResourceType, -) -from datajunction_server.utils import ( - SEPARATOR, - get_current_user, - get_session, - get_settings, -) - -settings = get_settings() - - -# ============================================================================ -# Access Check Modes -# ============================================================================ - - -class AccessDenialMode(Enum): - """ - How to handle denied access requests. - """ - - FILTER = "filter" # Return only approved requests - RAISE = "raise" # Raise exception if any denied - RETURN = "return" # Return all requests with approved field set - - -# ============================================================================ -# Authorization Context -# ============================================================================ - - -@dataclass(frozen=True) -class AuthContext: - """ - Authorization context for a user. - - Contains all data needed to make authorization decisions, - pre-loaded and ready for fast in-memory checks. - - This separates authorization data from the full User model, - allowing for clean caching, testing, and type safety. - """ - - user_id: int - username: str - oauth_provider: Optional[str] - role_assignments: List[RoleAssignment] # Direct + groups, flattened - - @classmethod - async def from_user( - cls, - session: AsyncSession, - user: User, - group_membership_service: GroupMembershipService | None = None, - ) -> "AuthContext": - """ - Build authorization context from a User object. - - This loads all effective role assignments (direct + group-based) - using the configured GroupMembershipService. - - Args: - session: Database session - user: User to build context for - group_membership_service: Optional service override - - Returns: - AuthContext ready for authorization checks - """ - assignments = await cls.get_effective_assignments( - session=session, - user=user, - group_membership_service=group_membership_service, - ) - - return cls( - user_id=user.id, - username=user.username, - oauth_provider=user.oauth_provider, - role_assignments=assignments, - ) - - @classmethod - async def get_effective_assignments( - cls, - session: AsyncSession, - user: User, - group_membership_service: GroupMembershipService | None = None, - ) -> List[RoleAssignment]: - """ - Get all effective role assignments for a user (direct + group-based). - - This function: - 1. Takes user's direct role_assignments - 2. Calls GroupMembershipService to get groups (LDAP/local/etc.) - 3. Loads those groups' role_assignments from DJ database - 4. Returns flattened list - - Args: - session: Database session - user: User to get assignments for - group_membership_service: Optional service override - - Returns: - Flat list of all role assignments that apply to this user - """ - from datajunction_server.database.rbac import Role as RoleModel - - # Start with user's direct assignments - assignments = list(user.role_assignments) - - # Get group membership service - if group_membership_service is None: - group_membership_service = get_group_membership_service() - - # Get groups from service (could be LDAP, local DB, etc.) - group_usernames = await group_membership_service.get_user_groups( - session, - user.username, - ) - - if not group_usernames: - return assignments # No groups - - # Load groups from DJ database with their role_assignments - stmt = ( - select(User) - .where(User.username.in_(group_usernames)) - .options( - selectinload(User.role_assignments) - .selectinload(RoleAssignment.role) - .selectinload(RoleModel.scopes), - ) - ) - result = await session.execute(stmt) - groups = result.scalars().all() - - # Flatten group assignments into the list - for group in groups: - assignments.extend(group.role_assignments) - - return assignments - - -# ============================================================================ -# New FastAPI-style Authorization Service -# ============================================================================ - - -class AuthorizationService(ABC): - """ - Abstract base class for authorization strategies. - - Authorization is performed on a pre-loaded authorization context. - - Implementations of this base class decide exactly how to authorize requests: - - RBACAuthorizationService: Uses pre-loaded roles/scopes (default) - - PassthroughAuthorizationService: Always approve (testing/permissive) - - Custom: Your own authorization logic - - Each implementation should define a `name` class attribute to register itself. - """ - - name: str # Subclasses must define this - - @abstractmethod - def authorize( - self, - auth_context: AuthContext, - requests: list[ResourceRequest], - ) -> list[AccessDecision]: - """ - Authorize resource requests for a user. - - This method should mutate the `approved` field on each request - to indicate whether access is granted. - - Args: - auth_context: Pre-loaded authorization context with all needed data - requests: List of resource requests to authorize - - Returns: - The same list of requests with approved=True/False set on each - """ - - -class RBACAuthorizationService(AuthorizationService): - """ - Default RBAC implementation using pre-loaded roles and scopes. - - This implementation: - 1. Works on AuthContext with pre-loaded role_assignments (direct + groups) - 2. Falls back to default_access_policy if no explicit rule exists - 3. Respects role expiration - 4. Synchronous - works on eagerly loaded data - - Group Membership Integration: - - Supports pluggable GroupMembershipService (LDAP, local DB, etc.) - - Groups are loaded when building AuthContext via from_user() - - No DB queries during authorization - all data pre-loaded - """ - - name = "rbac" - - PERMISSION_HIERARCHY = { - ResourceAction.MANAGE: { - ResourceAction.MANAGE, - ResourceAction.DELETE, - ResourceAction.WRITE, - ResourceAction.EXECUTE, - ResourceAction.READ, - }, - ResourceAction.DELETE: { - ResourceAction.DELETE, - ResourceAction.WRITE, - ResourceAction.READ, - }, - ResourceAction.WRITE: { - ResourceAction.WRITE, - ResourceAction.READ, - }, - ResourceAction.EXECUTE: { - ResourceAction.EXECUTE, - ResourceAction.READ, - }, - ResourceAction.READ: { - ResourceAction.READ, - }, - } - - def authorize( - self, - auth_context: AuthContext, - requests: list[ResourceRequest], - ) -> list[AccessDecision]: - """ - Authorize using pre-loaded RBAC roles and scopes (sync). - - Args: - auth_context: Pre-loaded authorization context with role assignments - requests: Resource requests to authorize - - Returns: - Same list of requests with approved=True/False set - """ - return [self._make_decision(auth_context, request) for request in requests] - - def _make_decision( - self, - auth_context: AuthContext, - request: ResourceRequest, - ) -> AccessDecision: - """ - Convert ResourceRequest to AccessDecision. - """ - has_grant = self.has_permission( - assignments=auth_context.role_assignments, - action=request.verb, - resource_type=request.access_object.resource_type, - resource_name=request.access_object.name, - ) - return AccessDecision( - request=request, - approved=(has_grant or settings.default_access_policy == "permissive"), - ) - - @classmethod - def resource_matches_pattern(cls, resource_name: str, pattern: str) -> bool: - """ - Check if resource name matches a pattern with wildcard support. - - resource_matches_pattern("finance.revenue", "finance.*") --> True - resource_matches_pattern("finance.quarterly.revenue", "finance.*") --> True - resource_matches_pattern("users.alice.dashboard", "users.alice.*") --> True - resource_matches_pattern("marketing.revenue", "finance.*") --> False - resource_matches_pattern("anything", "*") --> True - resource_matches_pattern("finance", "finance.*") --> False - """ - if pattern == "*": - return True # Match everything - - if "*" not in pattern: - return resource_name == pattern # Exact match - - # Wildcard pattern: finance.* matches finance.revenue and finance.quarterly.revenue - # But NOT just "finance" (must have something after the dot) - pattern_prefix = pattern.rstrip("*").rstrip(SEPARATOR) - - if not pattern_prefix: - return True # Pattern was just "*" - - # Resource must start with pattern_prefix followed by a dot - # (not an exact match to pattern_prefix, that would be handled by exact pattern) - return resource_name.startswith(pattern_prefix + SEPARATOR) - - @classmethod - def has_permission( - cls, - assignments: List, - action: ResourceAction, - resource_type: ResourceType, - resource_name: str, - ) -> bool: - """ - Determine if a list of role assignments grants the requested permission. - - This method iterates through all provided role assignments, checking if any - grant the specified action on the given resource. Expired assignments are - automatically skipped. Returns True if at least one valid assignment grants - access, False otherwise. - - Args: - assignments: List of role assignments to check - action: The action being requested (READ, WRITE, etc.) - resource_type: Type of resource (NODE, NAMESPACE, etc.) - resource_name: Full name/identifier of the resource - - Returns: - True if permission is granted, False otherwise - """ - for assignment in assignments: - # Skip expired assignments - if assignment.expires_at and assignment.expires_at < datetime.now( - timezone.utc, - ): - continue - - # Check each scope in the role - for scope in assignment.role.scopes: - # Check if scope grants permission for this resource - if cls._scope_grants_permission( - scope, - action, - resource_type, - resource_name, - ): - return True - - return False - - @classmethod - def _scope_grants_permission( - cls, - scope, - action: ResourceAction, - resource_type: ResourceType, - resource_name: str, - ) -> bool: - """ - Check if a scope grants permission for a resource. - - Handles: - 1. Permission hierarchy (MANAGE > DELETE > WRITE > READ, EXECUTE > READ) - 2. Empty/None scope_value or "*" = global access - 3. Wildcard pattern matching (finance.*) - 4. Cross-resource-type: namespace scope covers nodes in that namespace - """ - # Check permission hierarchy: does scope.action grant the requested action? - granted_actions = cls.PERMISSION_HIERARCHY.get(scope.action, {scope.action}) - if action not in granted_actions: - return False - - # Handle global access (empty string, None, or "*" scope_value) - if not scope.scope_value or scope.scope_value == "" or scope.scope_value == "*": - # Global scope matches any resource of the same type - return scope.scope_type == resource_type - - # Same resource type - use pattern matching - if scope.scope_type == resource_type: - return cls.resource_matches_pattern(resource_name, scope.scope_value) - - # Cross-resource-type: namespace scope can cover nodes - if ( - scope.scope_type == ResourceType.NAMESPACE - and resource_type == ResourceType.NODE - ): - # Check if node name matches the namespace pattern - return cls.resource_matches_pattern(resource_name, scope.scope_value) - - # No match - return False - - -class PassthroughAuthorizationService(AuthorizationService): - """ - Always approves all requests without checking permissions. - - Useful for: - - Local development - - Testing - - Fully permissive deployments - - Gradual RBAC rollout (start permissive, add rules incrementally) - """ - - name = "passthrough" - - def authorize( - self, - auth_context: AuthContext, - requests: list[ResourceRequest], - ) -> list[AccessDecision]: - """Approve all requests without checks (sync).""" - return [AccessDecision(request=request, approved=True) for request in requests] - - -@lru_cache(maxsize=None) -def get_authorization_service() -> AuthorizationService: - """ - Factory function to get the configured authorization service. - - This is used as a FastAPI dependency. The service can be overridden - via app.dependency_overrides for testing or custom deployments. - - Built-in providers: - - "rbac": Role-based access control using roles/scopes tables (default) - - "passthrough": Always approve all requests - - Configure via environment variable: - ```bash - AUTHORIZATION_PROVIDER=rbac # or passthrough - ``` - - Custom providers can be added by: - 1. Subclassing AuthorizationService - 2. Defining a `name` class attribute - 3. Importing the class before app starts - - Example: - ```python - class ExampleAuthService(AuthorizationService): - name = "example" - - def authorize(self, user, requests): - # Your sync authorization logic - return requests - ``` - - Returns: - AuthorizationService implementation - - Raises: - ValueError: If the configured provider is unknown - """ - provider = getattr(settings, "authorization_provider", "rbac") - - # Discover all subclasses - providers = {} - for subclass in AuthorizationService.__subclasses__(): - if hasattr(subclass, "name"): - providers[subclass.name] = subclass - if subclass.name == provider: - return subclass() # type: ignore[abstract] - - available = ", ".join(sorted(providers.keys())) - raise ValueError( - f"Unknown authorization_provider: '{provider}'. " - f"Available providers: {available}", - ) - - -async def get_auth_context( - session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), -) -> AuthContext: - """Build authorization context with user + group assignments.""" - return await AuthContext.from_user(session, current_user) - - -class AccessChecker: - """Collects authorization requests and validates them.""" - - def __init__(self, auth_context: AuthContext): - self.auth_context = auth_context - self.requests: list[ResourceRequest] = [] - - def add_request(self, request: ResourceRequest): - """Add a request to check.""" - self.requests.append(request) - - def add_requests(self, requests: list[ResourceRequest]): - """Add requests to check.""" - self.requests.extend(requests) - - @classmethod - def resource_request_from_node( - cls, - node: Node, - action: ResourceAction, - ) -> ResourceRequest: - """Create ResourceRequest from a Node.""" - return ResourceRequest( - verb=action, - access_object=Resource.from_node(node), - ) - - def add_request_by_node_name(self, node_name: str, action: ResourceAction): - """Add request by node name.""" - self.requests.append( - ResourceRequest( - verb=action, - access_object=Resource(name=node_name, resource_type=ResourceType.NODE), - ), - ) - - def add_node(self, node: Node, action: ResourceAction): - """Add request for a node.""" - node_request = self.resource_request_from_node(node, action) - self.add_request(node_request) - - def add_nodes(self, nodes: list[Node], action: ResourceAction): - """Add requests for multiple nodes.""" - self.requests.extend( - self.resource_request_from_node(node, action) for node in nodes - ) - - @classmethod - def resource_request_from_namespace( - cls, - namespace: str, - action: ResourceAction, - ) -> ResourceRequest: - """Create ResourceRequest from a namespace.""" - return ResourceRequest( - verb=action, - access_object=Resource.from_namespace(namespace), - ) - - def add_namespace(self, namespace: str, action: ResourceAction): - """Add request for a namespace.""" - namespace_request = self.resource_request_from_namespace(namespace, action) - self.add_request(namespace_request) - - def add_namespaces(self, namespaces: list[str], action: ResourceAction): - """Add requests for multiple namespaces.""" - self.requests.extend( - self.resource_request_from_namespace(namespace, action) - for namespace in namespaces - ) - - async def check( - self, - on_denied: AccessDenialMode = AccessDenialMode.FILTER, - ) -> list[AccessDecision]: - """ - Validate all requests using AuthorizationService. - - Args: - on_denied: How to handle denied requests - - FILTER: Return only approved (default) - - RAISE: Raise exception if any denied - - RETURN_ALL: Return all with approved field set - """ - auth_service = get_authorization_service() - access_decisions = auth_service.authorize(self.auth_context, self.requests) - - if on_denied == AccessDenialMode.RETURN: - return access_decisions - elif on_denied == AccessDenialMode.RAISE: - denied: list[AccessDecision] = [ - decision for decision in access_decisions if not decision.approved - ] - if denied: - from datajunction_server.errors import DJAuthorizationException - - # Show first 5 denied resources - denied_names = [d.request.access_object.name for d in denied[:5]] - more_count = max(0, len(denied) - 5) - - raise DJAuthorizationException( - message=( - f"Access denied to {len(denied)} resource(s): " - f"{', '.join(denied_names)}" - + (f" and {more_count} more" if more_count else "") - ), - ) - return access_decisions - # Default: FILTER - return [decision for decision in access_decisions if decision.approved] - - async def approved_resource_names(self) -> list[str]: - """Get approved resource names.""" - return [ - decision.request.access_object.name - for decision in await self.check(on_denied=AccessDenialMode.FILTER) - ] - - -def get_access_checker( - auth_context: AuthContext = Depends(get_auth_context), -) -> AccessChecker: - """Provide AccessChecker with pre-loaded context.""" - return AccessChecker(auth_context) diff --git a/datajunction-server/datajunction_server/internal/access/authorization/__init__.py b/datajunction-server/datajunction_server/internal/access/authorization/__init__.py new file mode 100644 index 000000000..7daa449cb --- /dev/null +++ b/datajunction-server/datajunction_server/internal/access/authorization/__init__.py @@ -0,0 +1,31 @@ +"""All authorization functions.""" + +__all__ = [ + "AuthContext", + "get_auth_context", + "AccessChecker", + "get_access_checker", + "AccessDenialMode", + "AuthorizationService", + "RBACAuthorizationService", + "PassthroughAuthorizationService", + "get_authorization_service", +] + +from datajunction_server.internal.access.authorization.context import ( + AuthContext, + get_auth_context, +) + +from datajunction_server.internal.access.authorization.validator import ( + AccessChecker, + get_access_checker, + AccessDenialMode, +) + +from datajunction_server.internal.access.authorization.service import ( + AuthorizationService, + RBACAuthorizationService, + PassthroughAuthorizationService, + get_authorization_service, +) diff --git a/datajunction-server/datajunction_server/internal/access/authorization/context.py b/datajunction-server/datajunction_server/internal/access/authorization/context.py new file mode 100644 index 000000000..7eec35b43 --- /dev/null +++ b/datajunction-server/datajunction_server/internal/access/authorization/context.py @@ -0,0 +1,130 @@ +""" +Authorization context for a user, pre-loaded with all roles. +""" + +from fastapi import Depends +from dataclasses import dataclass +from typing import List, Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + + +from datajunction_server.internal.access.group_membership import ( + get_group_membership_service, +) +from datajunction_server.database.rbac import RoleAssignment, Role +from datajunction_server.database.user import User +from datajunction_server.utils import ( + get_current_user, + get_session, + get_settings, +) + +settings = get_settings() + + +@dataclass(frozen=True) +class AuthContext: + """ + Authorization context for a user. + + Contains all data needed to make authorization decisions, + pre-loaded and ready for fast in-memory checks. + + This separates authorization data from the full User model, + allowing for clean caching, testing, and type safety. + """ + + user_id: int + username: str + oauth_provider: Optional[str] + role_assignments: List[RoleAssignment] # Direct + groups, flattened + + @classmethod + async def from_user( + cls, + session: AsyncSession, + user: User, + ) -> "AuthContext": + """ + Build authorization context from a User object. + + This loads all effective role assignments (direct + group-based) + for the user using the configured GroupMembershipService. + + Args: + session: db session + user: user to build context for + + Returns: + AuthContext ready for authorization checks + """ + assignments = await cls.get_effective_assignments( + session=session, + user=user, + ) + + return cls( + user_id=user.id, + username=user.username, + oauth_provider=user.oauth_provider, + role_assignments=assignments, + ) + + @classmethod + async def get_effective_assignments( + cls, + session: AsyncSession, + user: User, + ) -> List[RoleAssignment]: + """ + Get all effective role assignments for a user (direct + group-based). + + Args: + session: db session + user: user to get assignments for + Returns: + list of all role assignments that apply to this user + """ + group_membership_service = get_group_membership_service() + + # Start with user's direct assignments + assignments = list(user.role_assignments) + + # Get groups from service (could be LDAP, local DB, etc.) + group_usernames = await group_membership_service.get_user_groups( + session, + user.username, + ) + + if not group_usernames: + return assignments # No groups + + # Load groups from DJ database with their role_assignments + stmt = ( + select(User) + .where(User.username.in_(group_usernames)) + .options( + selectinload(User.role_assignments) + .selectinload(RoleAssignment.role) + .selectinload(Role.scopes), + ) + ) + result = await session.execute(stmt) + groups = result.scalars().all() + + # Flatten group assignments into the list + for group in groups: + assignments.extend(group.role_assignments) + + return assignments + + +async def get_auth_context( + session: AsyncSession = Depends(get_session), + current_user: User = Depends(get_current_user), +) -> AuthContext: + """Build authorization context with user + group assignments.""" + return await AuthContext.from_user(session, current_user) diff --git a/datajunction-server/datajunction_server/internal/access/authorization/service.py b/datajunction-server/datajunction_server/internal/access/authorization/service.py new file mode 100644 index 000000000..6abd28f12 --- /dev/null +++ b/datajunction-server/datajunction_server/internal/access/authorization/service.py @@ -0,0 +1,329 @@ +""" +Authorization service implementations for access control. +""" + +from abc import ABC, abstractmethod +from datetime import datetime, timezone +from functools import lru_cache +from typing import List + + +from datajunction_server.models.access import ( + AccessDecision, + ResourceAction, + ResourceRequest, + ResourceType, +) +from datajunction_server.internal.access.authorization.context import ( + AuthContext, +) +from datajunction_server.utils import ( + SEPARATOR, + get_settings, +) + +settings = get_settings() + + +class AuthorizationService(ABC): + """ + Abstract base class for authorization strategies. + + Authorization is performed on a pre-loaded authorization context. + + Implementations of this base class decide exactly how to authorize requests: + - RBACAuthorizationService: Uses pre-loaded roles/scopes (default) + - PassthroughAuthorizationService: Always approve (testing/permissive) + - Custom: Your own authorization logic + + Each implementation should define a `name` class attribute to register itself. + """ + + name: str # Subclasses must define this + + @abstractmethod + def authorize( + self, + auth_context: AuthContext, + requests: list[ResourceRequest], + ) -> list[AccessDecision]: + """ + Authorize resource requests for a user. + + This method should mutate the `approved` field on each request + to indicate whether access is granted. + + Args: + auth_context: Pre-loaded authorization context with all needed data + requests: List of resource requests to authorize + + Returns: + The same list of requests with approved=True/False set on each + """ + + +class RBACAuthorizationService(AuthorizationService): + """ + Default RBAC implementation using pre-loaded roles and scopes. + + This implementation: + 1. Works on AuthContext with pre-loaded role_assignments (direct + groups) + 2. Falls back to default_access_policy if no explicit rule exists + 3. Respects role expiration + 4. Synchronous - works on eagerly loaded data + + Group Membership Integration: + - Supports pluggable GroupMembershipService (LDAP, local DB, etc.) + - Groups are loaded when building AuthContext via from_user() + - No DB queries during authorization - all data pre-loaded + """ + + name = "rbac" + + PERMISSION_HIERARCHY = { + ResourceAction.MANAGE: { + ResourceAction.MANAGE, + ResourceAction.DELETE, + ResourceAction.WRITE, + ResourceAction.EXECUTE, + ResourceAction.READ, + }, + ResourceAction.DELETE: { + ResourceAction.DELETE, + ResourceAction.WRITE, + ResourceAction.READ, + }, + ResourceAction.WRITE: { + ResourceAction.WRITE, + ResourceAction.READ, + }, + ResourceAction.EXECUTE: { + ResourceAction.EXECUTE, + ResourceAction.READ, + }, + ResourceAction.READ: { + ResourceAction.READ, + }, + } + + def authorize( + self, + auth_context: AuthContext, + requests: list[ResourceRequest], + ) -> list[AccessDecision]: + """ + Authorize using pre-loaded RBAC roles and scopes (sync). + + Args: + auth_context: Pre-loaded authorization context with role assignments + requests: Resource requests to authorize + + Returns: + Same list of requests with approved=True/False set + """ + return [self._make_decision(auth_context, request) for request in requests] + + def _make_decision( + self, + auth_context: AuthContext, + request: ResourceRequest, + ) -> AccessDecision: + """ + Convert ResourceRequest to AccessDecision. + """ + has_grant = self.has_permission( + assignments=auth_context.role_assignments, + action=request.verb, + resource_type=request.access_object.resource_type, + resource_name=request.access_object.name, + ) + return AccessDecision( + request=request, + approved=(has_grant or settings.default_access_policy == "permissive"), + ) + + @classmethod + def resource_matches_pattern(cls, resource_name: str, pattern: str) -> bool: + """ + Check if resource name matches a pattern with wildcard support. + + resource_matches_pattern("finance.revenue", "finance.*") --> True + resource_matches_pattern("finance.quarterly.revenue", "finance.*") --> True + resource_matches_pattern("users.alice.dashboard", "users.alice.*") --> True + resource_matches_pattern("marketing.revenue", "finance.*") --> False + resource_matches_pattern("anything", "*") --> True + resource_matches_pattern("finance", "finance.*") --> False + """ + if pattern == "*": + return True # Match everything + + if "*" not in pattern: + return resource_name == pattern # Exact match + + # Wildcard pattern: finance.* matches finance.revenue and finance.quarterly.revenue + # But NOT just "finance" (must have something after the dot) + pattern_prefix = pattern.rstrip("*").rstrip(SEPARATOR) + + if not pattern_prefix: + return True # Pattern was just "*" + + # Resource must start with pattern_prefix followed by a dot + # (not an exact match to pattern_prefix, that would be handled by exact pattern) + return resource_name.startswith(pattern_prefix + SEPARATOR) + + @classmethod + def has_permission( + cls, + assignments: List, + action: ResourceAction, + resource_type: ResourceType, + resource_name: str, + ) -> bool: + """ + Determine if a list of role assignments grants the requested permission. + + This method iterates through all provided role assignments, checking if any + grant the specified action on the given resource. Expired assignments are + automatically skipped. Returns True if at least one valid assignment grants + access, False otherwise. + + Args: + assignments: List of role assignments to check + action: The action being requested (READ, WRITE, etc.) + resource_type: Type of resource (NODE, NAMESPACE, etc.) + resource_name: Full name/identifier of the resource + + Returns: + True if permission is granted, False otherwise + """ + for assignment in assignments: + # Skip expired assignments + if assignment.expires_at and assignment.expires_at < datetime.now( + timezone.utc, + ): + continue + + # Check each scope in the role + for scope in assignment.role.scopes: + # Check if scope grants permission for this resource + if cls._scope_grants_permission( + scope, + action, + resource_type, + resource_name, + ): + return True + + return False + + @classmethod + def _scope_grants_permission( + cls, + scope, + action: ResourceAction, + resource_type: ResourceType, + resource_name: str, + ) -> bool: + """ + Check if a scope grants permission for a resource. + + Handles: + 1. Permission hierarchy (MANAGE > DELETE > WRITE > READ, EXECUTE > READ) + 2. Empty/None scope_value or "*" = global access + 3. Wildcard pattern matching (finance.*) + 4. Cross-resource-type: namespace scope covers nodes in that namespace + """ + # Check permission hierarchy: does scope.action grant the requested action? + granted_actions = cls.PERMISSION_HIERARCHY.get(scope.action, {scope.action}) + if action not in granted_actions: + return False + + # Handle global access (empty string, None, or "*" scope_value) + if not scope.scope_value or scope.scope_value == "" or scope.scope_value == "*": + # Global scope matches any resource of the same type + return scope.scope_type == resource_type + + # Same resource type - use pattern matching + if scope.scope_type == resource_type: + return cls.resource_matches_pattern(resource_name, scope.scope_value) + + # Cross-resource-type: namespace scope can cover nodes + if ( + scope.scope_type == ResourceType.NAMESPACE + and resource_type == ResourceType.NODE + ): + # Check if node name matches the namespace pattern + return cls.resource_matches_pattern(resource_name, scope.scope_value) + + # No match + return False + + +class PassthroughAuthorizationService(AuthorizationService): + """ + Always approves all requests (for testing or permissive environments). + """ + + name = "passthrough" + + def authorize( + self, + auth_context: AuthContext, + requests: list[ResourceRequest], + ) -> list[AccessDecision]: + """Approve all requests without checks (sync).""" + return [AccessDecision(request=request, approved=True) for request in requests] + + +@lru_cache(maxsize=None) +def get_authorization_service() -> AuthorizationService: + """ + Factory function to get the configured authorization service. + + This is used as a FastAPI dependency. The service can be overridden + via app.dependency_overrides for testing or custom deployments. + + Built-in providers: + - "rbac": Role-based access control using roles/scopes tables (default) + - "passthrough": Always approve all requests + + Configure via environment variable: + ```bash + AUTHORIZATION_PROVIDER=rbac # or passthrough + ``` + + Custom providers can be added by: + 1. Subclassing AuthorizationService + 2. Defining a `name` class attribute + 3. Importing the class before app starts + + Example: + ```python + class ExampleAuthService(AuthorizationService): + name = "example" + + def authorize(self, user, requests): + # Your sync authorization logic + return requests + ``` + + Returns: + AuthorizationService implementation + + Raises: + ValueError: If the configured provider is unknown + """ + provider = getattr(settings, "authorization_provider", "rbac") + + # Discover all subclasses + providers = {} + for subclass in AuthorizationService.__subclasses__(): + providers[subclass.name] = subclass + if subclass.name == provider: + return subclass() # type: ignore[abstract] + + available = ", ".join(sorted(providers.keys())) + raise ValueError( + f"Unknown authorization_provider: '{provider}'. " + f"Available providers: {available}", + ) diff --git a/datajunction-server/datajunction_server/internal/access/authorization/validator.py b/datajunction-server/datajunction_server/internal/access/authorization/validator.py new file mode 100644 index 000000000..4ffab7dd0 --- /dev/null +++ b/datajunction-server/datajunction_server/internal/access/authorization/validator.py @@ -0,0 +1,164 @@ +""" +Access validation collection and helper functions. +""" + +from fastapi import Depends +from enum import Enum + + +from datajunction_server.internal.access.authorization.service import ( + get_authorization_service, +) +from datajunction_server.database.node import Node +from datajunction_server.models.access import ( + AccessDecision, + Resource, + ResourceAction, + ResourceRequest, + ResourceType, +) +from datajunction_server.internal.access.authorization.context import ( + AuthContext, + get_auth_context, +) +from datajunction_server.utils import ( + get_settings, +) + +settings = get_settings() + + +class AccessDenialMode(Enum): + """ + How to handle denied access requests. + """ + + FILTER = "filter" # Return only approved requests + RAISE = "raise" # Raise exception if any denied + RETURN = "return" # Return all requests with approved field set + + +class AccessChecker: + """Collects authorization requests and validates them.""" + + def __init__(self, auth_context: AuthContext): + self.auth_context = auth_context + self.requests: list[ResourceRequest] = [] + + def add_request(self, request: ResourceRequest): + """Add a request to check.""" + self.requests.append(request) + + def add_requests(self, requests: list[ResourceRequest]): + """Add requests to check.""" + self.requests.extend(requests) + + @classmethod + def resource_request_from_node( + cls, + node: Node, + action: ResourceAction, + ) -> ResourceRequest: + """Create ResourceRequest from a Node.""" + return ResourceRequest( + verb=action, + access_object=Resource.from_node(node), + ) + + def add_request_by_node_name(self, node_name: str, action: ResourceAction): + """Add request by node name.""" + self.requests.append( + ResourceRequest( + verb=action, + access_object=Resource(name=node_name, resource_type=ResourceType.NODE), + ), + ) + + def add_node(self, node: Node, action: ResourceAction): + """Add request for a node.""" + node_request = self.resource_request_from_node(node, action) + self.add_request(node_request) + + def add_nodes(self, nodes: list[Node], action: ResourceAction): + """Add requests for multiple nodes.""" + self.requests.extend( + self.resource_request_from_node(node, action) for node in nodes + ) + + @classmethod + def resource_request_from_namespace( + cls, + namespace: str, + action: ResourceAction, + ) -> ResourceRequest: + """Create ResourceRequest from a namespace.""" + return ResourceRequest( + verb=action, + access_object=Resource.from_namespace(namespace), + ) + + def add_namespace(self, namespace: str, action: ResourceAction): + """Add request for a namespace.""" + namespace_request = self.resource_request_from_namespace(namespace, action) + self.add_request(namespace_request) + + def add_namespaces(self, namespaces: list[str], action: ResourceAction): + """Add requests for multiple namespaces.""" + self.requests.extend( + self.resource_request_from_namespace(namespace, action) + for namespace in namespaces + ) + + async def check( + self, + on_denied: AccessDenialMode = AccessDenialMode.FILTER, + ) -> list[AccessDecision]: + """ + Validate all requests using AuthorizationService. + + Args: + on_denied: How to handle denied requests + - FILTER: Return only approved (default) + - RAISE: Raise exception if any denied + - RETURN_ALL: Return all with approved field set + """ + auth_service = get_authorization_service() + access_decisions = auth_service.authorize(self.auth_context, self.requests) + + if on_denied == AccessDenialMode.RETURN: + return access_decisions + elif on_denied == AccessDenialMode.RAISE: + denied: list[AccessDecision] = [ + decision for decision in access_decisions if not decision.approved + ] + if denied: + from datajunction_server.errors import DJAuthorizationException + + # Show first 5 denied resources + denied_names = [d.request.access_object.name for d in denied[:5]] + more_count = max(0, len(denied) - 5) + + raise DJAuthorizationException( + message=( + f"Access denied to {len(denied)} resource(s): " + f"{', '.join(denied_names)}" + + (f" and {more_count} more" if more_count else "") + ), + ) + return access_decisions + # Default: FILTER + return [decision for decision in access_decisions if decision.approved] + + async def approved_resource_names(self) -> list[str]: + """Get approved resource names.""" + return [ + decision.request.access_object.name + for decision in await self.check(on_denied=AccessDenialMode.FILTER) + ] + + +def get_access_checker( + auth_context: AuthContext = Depends(get_auth_context), +) -> AccessChecker: + """Provide AccessChecker with pre-loaded context.""" + return AccessChecker(auth_context) diff --git a/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py b/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py index d63b1e57e..0dc70501f 100644 --- a/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py +++ b/datajunction-server/datajunction_server/internal/caching/query_cache_manager.py @@ -64,7 +64,6 @@ async def build_access_checker_from_request( session: AsyncSession, ) -> AccessChecker: """Helper to build checker from request + session.""" - print("Building access checker from request") current_user = await get_current_user(request) auth_context = await AuthContext.from_user(session, current_user) return get_access_checker(auth_context) diff --git a/datajunction-server/datajunction_server/internal/nodes.py b/datajunction-server/datajunction_server/internal/nodes.py index 629522223..19254077f 100644 --- a/datajunction-server/datajunction_server/internal/nodes.py +++ b/datajunction-server/datajunction_server/internal/nodes.py @@ -15,7 +15,6 @@ from datajunction_server.internal.access.authorization import ( AccessChecker, - AccessDenialMode, ) from datajunction_server.internal.caching.interface import Cache from datajunction_server.models.query import QueryCreate @@ -58,7 +57,6 @@ ) from datajunction_server.internal.history import ActivityType, EntityType from datajunction_server.internal.validation import NodeValidator, validate_node_data -from datajunction_server.models import access from datajunction_server.models.attribute import ( AttributeTypeIdentifier, ColumnAttributes, @@ -912,19 +910,6 @@ async def update_any_node( ) node = cast(Node, node) - # Check that the user has access to modify this node - if access_checker: - access_checker.add_request( - access.ResourceRequest( - access_object=access.Resource( - resource_type=access.ResourceType.NODE, - name=node.name, - ), - verb=access.ResourceAction.WRITE, - ), - ) - await access_checker.check(on_denied=AccessDenialMode.RAISE) - if data.owners and data.owners != [owner.username for owner in node.owners]: await update_owners(session, node, data.owners, current_user, save_history) diff --git a/datajunction-server/datajunction_server/internal/sql.py b/datajunction-server/datajunction_server/internal/sql.py index 8ee18c082..da0ab8a82 100644 --- a/datajunction-server/datajunction_server/internal/sql.py +++ b/datajunction-server/datajunction_server/internal/sql.py @@ -293,7 +293,7 @@ async def build_sql_for_multiple_metrics( ) # Check authorization for all discovered nodes - if access_checker: + if access_checker: # pragma: no cover await access_checker.check(on_denied=AccessDenialMode.RAISE) columns = [ diff --git a/datajunction-server/tests/api/access_test.py b/datajunction-server/tests/api/access_test.py index e02adcb52..64ba10f69 100644 --- a/datajunction-server/tests/api/access_test.py +++ b/datajunction-server/tests/api/access_test.py @@ -1,5 +1,5 @@ """ -Tests for the data API. +Tests for access control across APIs. """ from http import HTTPStatus @@ -8,6 +8,7 @@ from datajunction_server.internal.access.authorization import AuthorizationService from datajunction_server.models import access +from datajunction_server.models.access import ResourceType class DenyAllAuthorizationService(AuthorizationService): @@ -24,6 +25,51 @@ def authorize(self, auth_context, requests): ] +class NamespaceOnlyAuthorizationService(AuthorizationService): + """ + Authorization service that allows namespace access but denies all node access. + """ + + name = "namespace_only" + + def __init__(self, allowed_namespaces: list[str]): + self.allowed_namespaces = allowed_namespaces + + def authorize(self, auth_context, requests): + decisions = [] + for request in requests: + approved = False + if request.access_object.resource_type == ResourceType.NAMESPACE: + # Allow access to specified namespaces + approved = request.access_object.name in self.allowed_namespaces + # Deny all NODE access + decisions.append(access.AccessDecision(request=request, approved=approved)) + return decisions + + +class PartialNodeAuthorizationService(AuthorizationService): + """ + Authorization service that allows access to specific namespaces and nodes. + """ + + name = "partial_node" + + def __init__(self, allowed_namespaces: list[str], allowed_nodes: list[str]): + self.allowed_namespaces = allowed_namespaces + self.allowed_nodes = allowed_nodes + + def authorize(self, auth_context, requests): + decisions = [] + for request in requests: + approved = False + if request.access_object.resource_type == ResourceType.NAMESPACE: + approved = request.access_object.name in self.allowed_namespaces + elif request.access_object.resource_type == ResourceType.NODE: + approved = request.access_object.name in self.allowed_nodes + decisions.append(access.AccessDecision(request=request, approved=approved)) + return decisions + + class TestDataAccessControl: """ Test the data access control. @@ -43,7 +89,7 @@ def get_deny_all_service(): return DenyAllAuthorizationService() mocker.patch( - "datajunction_server.internal.access.authorization.get_authorization_service", + "datajunction_server.internal.access.authorization.validator.get_authorization_service", get_deny_all_service, ) response = await module__client_with_examples.get("/data/basic.num_comments/") @@ -67,7 +113,7 @@ def get_deny_all_service(): return DenyAllAuthorizationService() mocker.patch( - "datajunction_server.internal.access.authorization.get_authorization_service", + "datajunction_server.internal.access.authorization.validator.get_authorization_service", get_deny_all_service, ) @@ -93,3 +139,107 @@ def get_deny_all_service(): assert "Access denied to" in data["message"] assert "foo.bar" in data["message"] assert response.status_code == HTTPStatus.FORBIDDEN + + +class TestNamespaceAccessControl: + """ + Test access control for the ``GET /namespaces/{namespace}/`` endpoint. + """ + + @pytest.mark.asyncio + async def test_list_nodes_with_no_namespace_access( + self, + module__client_with_examples: AsyncClient, + mocker, + ): + """ + User with no namespace READ access should get empty list. + """ + + def get_deny_all_service(): + return DenyAllAuthorizationService() + + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_deny_all_service, + ) + + response = await module__client_with_examples.get("/namespaces/default/") + assert response.status_code == HTTPStatus.OK + data = response.json() + assert data == [] + + @pytest.mark.asyncio + async def test_list_nodes_with_namespace_access_but_no_node_access( + self, + module__client_with_examples: AsyncClient, + mocker, + ): + """ + User with namespace READ access but no node READ access should get empty list. + """ + + def get_namespace_only_service(): + return NamespaceOnlyAuthorizationService(allowed_namespaces=["default"]) + + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_namespace_only_service, + ) + + response = await module__client_with_examples.get("/namespaces/default/") + assert response.status_code == HTTPStatus.OK + data = response.json() + assert data == [] + + @pytest.mark.asyncio + async def test_list_nodes_with_partial_node_access( + self, + module__client_with_examples: AsyncClient, + mocker, + ): + """ + User with namespace access and partial node access should get filtered list. + """ + allowed_nodes = [ + "default.repair_orders", + "default.hard_hat", + ] + + def get_partial_service(): + return PartialNodeAuthorizationService( + allowed_namespaces=["default"], + allowed_nodes=allowed_nodes, + ) + + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + get_partial_service, + ) + + response = await module__client_with_examples.get("/namespaces/default/") + assert response.status_code == HTTPStatus.OK + data = response.json() + + # Should only return the allowed nodes + returned_names = [node["name"] for node in data] + assert set(returned_names) == set(allowed_nodes) + + @pytest.mark.asyncio + async def test_list_nodes_with_full_access( + self, + module__client_with_examples: AsyncClient, + ): + """ + User with full access (PassthroughAuthorizationService) should get all nodes. + Default test client uses PassthroughAuthorizationService. + """ + response = await module__client_with_examples.get("/namespaces/default/") + assert response.status_code == HTTPStatus.OK + data = response.json() + + # Should return multiple nodes (the roads example has many) + assert len(data) > 0 + # Verify we get node details + assert all("name" in node for node in data) + assert all(node["name"].startswith("default.") for node in data) diff --git a/datajunction-server/tests/api/dimensions_access_test.py b/datajunction-server/tests/api/dimensions_access_test.py index 4d18207f3..8c7f1381d 100644 --- a/datajunction-server/tests/api/dimensions_access_test.py +++ b/datajunction-server/tests/api/dimensions_access_test.py @@ -39,7 +39,7 @@ def get_repair_only_service(): return RepairOnlyAuthorizationService() mocker.patch( - "datajunction_server.internal.access.authorization.get_authorization_service", + "datajunction_server.internal.access.authorization.validator.get_authorization_service", get_repair_only_service, ) diff --git a/datajunction-server/tests/api/graphql/common_dimensions_test.py b/datajunction-server/tests/api/graphql/common_dimensions_test.py index 2948c2cc4..74cfdbe34 100644 --- a/datajunction-server/tests/api/graphql/common_dimensions_test.py +++ b/datajunction-server/tests/api/graphql/common_dimensions_test.py @@ -95,7 +95,7 @@ async def test_get_common_dimensions( "role": None, "type": "int", } in data["data"]["commonDimensions"] - assert len(capture_queries) <= 18 # type: ignore + assert len(capture_queries) <= 28 # type: ignore @pytest.mark.asyncio diff --git a/datajunction-server/tests/api/namespaces_test.py b/datajunction-server/tests/api/namespaces_test.py index 0450f9a77..16d1136ba 100644 --- a/datajunction-server/tests/api/namespaces_test.py +++ b/datajunction-server/tests/api/namespaces_test.py @@ -865,7 +865,7 @@ def get_dbt_only_service(): return DbtOnlyAuthorizationService() mocker.patch( - "datajunction_server.internal.access.authorization.get_authorization_service", + "datajunction_server.internal.access.authorization.validator.get_authorization_service", get_dbt_only_service, ) @@ -911,7 +911,7 @@ def get_deny_all_service(): return DenyAllAuthorizationService() mocker.patch( - "datajunction_server.internal.access.authorization.get_authorization_service", + "datajunction_server.internal.access.authorization.validator.get_authorization_service", get_deny_all_service, ) response = await client_with_service_setup.get("/namespaces/") diff --git a/datajunction-server/tests/api/sql_test.py b/datajunction-server/tests/api/sql_test.py index 1ff8923b8..ca4c742f5 100644 --- a/datajunction-server/tests/api/sql_test.py +++ b/datajunction-server/tests/api/sql_test.py @@ -2748,7 +2748,7 @@ def get_deny_all_service(): return DenyAllAuthorizationService() mocker.patch( - "datajunction_server.internal.access.authorization.get_authorization_service", + "datajunction_server.internal.access.authorization.validator.get_authorization_service", get_deny_all_service, ) response = await module__client_with_examples.get( @@ -3397,7 +3397,7 @@ def get_selective_denial_service(): return SelectiveDenialAuthorizationService() mocker.patch( - "datajunction_server.internal.access.authorization.get_authorization_service", + "datajunction_server.internal.access.authorization.validator.get_authorization_service", return_value=SelectiveDenialAuthorizationService(), ) diff --git a/datajunction-server/tests/construction/build_test.py b/datajunction-server/tests/construction/build_test.py index a0b66d2f5..458e1f8e8 100644 --- a/datajunction-server/tests/construction/build_test.py +++ b/datajunction-server/tests/construction/build_test.py @@ -5,7 +5,7 @@ import pytest from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession - +import pytest_asyncio import datajunction_server.sql.parsing.types as ct from datajunction_server.construction.build import ( build_materialized_cube_node, @@ -23,10 +23,54 @@ from datajunction_server.models.node_type import NodeType from datajunction_server.naming import amenable_name from datajunction_server.sql.parsing.backends.antlr4 import ast, parse +from datajunction_server.internal.access.authorization.service import ( + AuthorizationService, +) +from datajunction_server.internal.access.authorization.validator import AccessChecker +from datajunction_server.internal.access.authorization.context import AuthContext +from datajunction_server.internal.access.authentication.basic import get_user +from datajunction_server.models.access import AccessDecision + + +class AllowAllAuthorizationService(AuthorizationService): + """ + Custom authorization service that allows all access. + """ + + name = "allow_all" + + def authorize(self, auth_context, requests): + return [AccessDecision(request=request, approved=True) for request in requests] + + +@pytest_asyncio.fixture +async def access_checker( + construction_session: AsyncSession, + default_user: User, + mocker, +) -> AccessChecker: + """ + Fixture to mock access checker to allow all access. + """ + user = await get_user(default_user.username, construction_session) + + def mock_get_allow_all_service(): + return AllowAllAuthorizationService() + + mocker.patch( + "datajunction_server.internal.access.authorization.validator.get_authorization_service", + mock_get_allow_all_service, + ) + return AccessChecker( + await AuthContext.from_user(construction_session, user), + ) @pytest.mark.asyncio -async def test_build_metric_with_dimensions_aggs(construction_session: AsyncSession): +async def test_build_metric_with_dimensions_aggs( + construction_session: AsyncSession, + access_checker: AccessChecker, +): """ Test building metric with dimensions """ @@ -40,6 +84,7 @@ async def test_build_metric_with_dimensions_aggs(construction_session: AsyncSess filters=[], dimensions=["basic.dimension.users.country", "basic.dimension.users.gender"], orderby=[], + access_checker=access_checker, ) expected = """ WITH basic_DOT_source_DOT_comments AS ( @@ -84,6 +129,7 @@ async def test_build_metric_with_dimensions_aggs(construction_session: AsyncSess @pytest.mark.asyncio async def test_build_metric_with_required_dimensions( construction_session: AsyncSession, + access_checker: AccessChecker, ): """ Test building metric with bound dimensions @@ -99,6 +145,7 @@ async def test_build_metric_with_required_dimensions( filters=[], dimensions=["basic.dimension.users.country", "basic.dimension.users.gender"], orderby=[], + access_checker=access_checker, ) expected = """ WITH basic_DOT_source_DOT_comments AS ( @@ -247,7 +294,10 @@ async def test_raise_on_build_without_required_dimension_column( @pytest.mark.asyncio -async def test_build_metric_with_dimensions_filters(construction_session: AsyncSession): +async def test_build_metric_with_dimensions_filters( + construction_session: AsyncSession, + access_checker: AccessChecker, +): """ Test building metric with dimension filters """ @@ -264,6 +314,7 @@ async def test_build_metric_with_dimensions_filters(construction_session: AsyncS ], dimensions=[], orderby=[], + access_checker=access_checker, ) expected = """ WITH basic_DOT_source_DOT_comments AS ( diff --git a/datajunction-server/tests/internal/authorization_test.py b/datajunction-server/tests/internal/authorization_test.py index 85be5092c..5e2b61b92 100644 --- a/datajunction-server/tests/internal/authorization_test.py +++ b/datajunction-server/tests/internal/authorization_test.py @@ -98,6 +98,55 @@ def test_nested_namespace_wildcard(self): "users.alice.*", ) + def test_edge_case_patterns(self): + """Test edge case patterns that reach the fallback logic (line 167-168). + + These patterns are unusual but should be handled gracefully: + - ".*" -> strips to empty string + - "**" -> strips to empty string + These are treated as global wildcards (match everything). + """ + # ".*" pattern - after stripping "*" and ".", becomes empty + assert RBACAuthorizationService.resource_matches_pattern( + "anything", + ".*", + ) + assert RBACAuthorizationService.resource_matches_pattern( + "finance.revenue", + ".*", + ) + assert RBACAuthorizationService.resource_matches_pattern( + "", + ".*", + ) + + # "**" pattern - after stripping "*", becomes empty + assert RBACAuthorizationService.resource_matches_pattern( + "anything", + "**", + ) + assert RBACAuthorizationService.resource_matches_pattern( + "deeply.nested.resource.name", + "**", + ) + + def test_wildcard_in_middle_not_supported(self): + """Test that wildcards in the middle of patterns don't work as expected. + + Note: The current implementation only supports trailing wildcards. + Patterns like "finance.*.revenue" are NOT supported as glob patterns. + """ + # "finance.*.revenue" - contains * but not at end + # This will strip trailing * (none) and compare as prefix + # So it won't match "finance.quarterly.revenue" + assert not RBACAuthorizationService.resource_matches_pattern( + "finance.quarterly.revenue", + "finance.*.revenue", + ) + + # It would only match if resource literally starts with "finance.*.revenue." + # which is unlikely in practice + @pytest.mark.asyncio class TestRBACPermissionChecks: @@ -450,7 +499,7 @@ async def test_rbac_service_with_permissions( ): """Test RBACAuthorizationService with granted permissions.""" mock_settings = mocker.patch( - "datajunction_server.internal.access.authorization.settings", + "datajunction_server.internal.access.authorization.service.settings", ) mock_settings.authorization_provider = "rbac" mock_settings.default_access_policy = "restrictive" @@ -505,7 +554,7 @@ async def test_rbac_service_with_permissions( async def test_get_authorization_service_factory(self, mocker): """Test the factory function returns correct service.""" mock_settings = mocker.patch( - "datajunction_server.internal.access.authorization.settings", + "datajunction_server.internal.access.authorization.service.settings", ) mock_settings.authorization_provider = "rbac" mock_settings.default_access_policy = "restrictive" @@ -590,7 +639,7 @@ async def test_user_inherits_group_permissions( user = await get_user(username=default_user.username, session=session) mock_settings = mocker.patch( - "datajunction_server.internal.access.authorization.settings", + "datajunction_server.internal.access.authorization.service.settings", ) mock_settings.authorization_provider = "rbac" mock_settings.default_access_policy = "restrictive" @@ -663,7 +712,7 @@ async def test_user_no_permission_without_group( user = await get_user(username=default_user.username, session=session) mock_settings = mocker.patch( - "datajunction_server.internal.access.authorization.settings", + "datajunction_server.internal.access.authorization.service.settings", ) mock_settings.authorization_provider = "rbac" mock_settings.default_access_policy = "restrictive" @@ -1473,7 +1522,7 @@ async def test_check_access_filter_mode_returns_only_approved( ] mock_settings = mocker.patch( - "datajunction_server.internal.access.authorization.settings", + "datajunction_server.internal.access.authorization.service.settings", ) mock_settings.authorization_provider = "rbac" mock_settings.default_access_policy = "restrictive" @@ -1487,7 +1536,7 @@ async def test_check_access_filter_mode_returns_only_approved( # Should only return the 2 approved (finance.* nodes) assert len(approved) == 2 - assert approved == {"finance.revenue", "finance.cost"} + assert approved == ["finance.revenue", "finance.cost"] async def test_check_access_raise_mode_throws_on_denial( self, @@ -1510,7 +1559,7 @@ async def test_check_access_raise_mode_throws_on_denial( ) mock_settings = mocker.patch( - "datajunction_server.internal.access.authorization.settings", + "datajunction_server.internal.access.authorization.service.settings", ) mock_settings.authorization_provider = "rbac" mock_settings.default_access_policy = "restrictive" @@ -1523,9 +1572,7 @@ async def test_check_access_raise_mode_throws_on_denial( await access_checker.check(on_denied=AccessDenialMode.RAISE) # Check exception message - assert "Access denied" in str(exc_info.value) - assert "WRITE" in str(exc_info.value) - assert "finance.revenue" in str(exc_info.value) + assert "Access denied to 1 resource(s): finance.revenue" in str(exc_info.value) async def test_check_access_raise_mode_succeeds_when_approved( self, @@ -1629,7 +1676,7 @@ async def test_check_access_return_mode( ] mock_settings = mocker.patch( - "datajunction_server.internal.access.authorization.settings", + "datajunction_server.internal.access.authorization.service.settings", ) mock_settings.authorization_provider = "rbac" mock_settings.default_access_policy = "restrictive" @@ -1804,10 +1851,13 @@ async def remove_user_from_group(self, session, username, group_name): # Use custom service mock_service = MockGroupService() + mocker.patch( + "datajunction_server.internal.access.authorization.context.get_group_membership_service", + lambda: mock_service, + ) assignments = await AuthContext.get_effective_assignments( session, user, - mock_service, ) # Should include mock group's assignment @@ -1931,7 +1981,7 @@ async def test_check_access_with_mixed_approval( ), ] mock_settings = mocker.patch( - "datajunction_server.internal.access.authorization.settings", + "datajunction_server.internal.access.authorization.service.settings", ) mock_settings.authorization_provider = "rbac" mock_settings.default_access_policy = "restrictive" diff --git a/datajunction-server/tests/internal/caching/query_cache_manager_test.py b/datajunction-server/tests/internal/caching/query_cache_manager_test.py index b86796edb..a31773744 100644 --- a/datajunction-server/tests/internal/caching/query_cache_manager_test.py +++ b/datajunction-server/tests/internal/caching/query_cache_manager_test.py @@ -1,4 +1,5 @@ import asyncio +from types import SimpleNamespace from unittest import mock from unittest.mock import patch @@ -13,6 +14,8 @@ QueryRequestParams, ) from datajunction_server.database.queryrequest import QueryBuildType +from datajunction_server.database.user import User, OAuthProvider +from datajunction_server.internal.access.authorization import AccessChecker class DummyRequest: @@ -27,6 +30,15 @@ def __init__(self, cache_control: str | None = None): self.headers = Headers(headers) self.method = "GET" + # Add state with a dummy user for get_current_user + self.state = SimpleNamespace( + user=User( + username="testuser", + email="test@example.com", + oauth_provider=OAuthProvider.BASIC, + ), + ) + @pytest.mark.asyncio async def test_cache_key_prefix_uses_query_type(): @@ -66,10 +78,17 @@ async def test_fallback_calls_get_measures_query(): """ Should call get_measures_query with correct args. """ - with patch( - "datajunction_server.internal.caching.query_cache_manager.get_measures_query", - return_value=[{"sql": "SELECT * FROM test"}], - ) as get_measures_query_mock: + mock_access_checker = mock.AsyncMock(spec=AccessChecker) + with ( + patch( + "datajunction_server.internal.caching.query_cache_manager.get_measures_query", + return_value=[{"sql": "SELECT * FROM test"}], + ) as get_measures_query_mock, + patch( + "datajunction_server.internal.caching.query_cache_manager.build_access_checker_from_request", + return_value=mock_access_checker, + ), + ): cache = CachelibCache() manager = QueryCacheManager(cache, QueryBuildType.MEASURES) params = QueryRequestParams( @@ -89,49 +108,56 @@ async def test_get_or_load_respects_cache_control(): """ Full flow test to ensure Cache-Control is respected. """ - with patch( - "datajunction_server.internal.caching.query_cache_manager.get_measures_query", - return_value=[{"sql": "SELECT * FROM test"}], - ): - with patch( + mock_access_checker = mock.AsyncMock(spec=AccessChecker) + with ( + patch( + "datajunction_server.internal.caching.query_cache_manager.get_measures_query", + return_value=[{"sql": "SELECT * FROM test"}], + ), + patch( "datajunction_server.internal.caching.query_cache_manager.VersionedQueryKey.version_query_request", return_value="versioned123", - ): - cache = CachelibCache() - manager = QueryCacheManager(cache, QueryBuildType.MEASURES) - params = QueryRequestParams( - nodes=["foo"], - dimensions=["dim1"], - filters=[], - ) - - # Put stale value in cache to test hit vs miss - key = await manager.build_cache_key(DummyRequest(), params) - cache.set(key, [{"sql": "CACHED"}]) - - background = BackgroundTasks() - - # `no-cache` => should bypass cache - request = DummyRequest(cache_control="no-cache") - result = await manager.get_or_load(background, request, params) - assert result == [{"sql": "SELECT * FROM test"}] - - # Run tasks, should store - for task in background.tasks: - await task() - assert cache.get(key) == [{"sql": "SELECT * FROM test"}] - - # `no-store` => should hit cache, but not store - cache.set(key, [{"sql": "CACHED"}]) - request = DummyRequest(cache_control="no-store") - result = await manager.get_or_load(background, request, params) - assert result == [{"sql": "CACHED"}] # hits stale - - # `no-cache, no-store` => should always fallback but never store - request = DummyRequest(cache_control="no-cache, no-store") - result = await manager.get_or_load(background, request, params) - assert result == [{"sql": "SELECT * FROM test"}] - cache.get(key) == [{"sql": "CACHED"}] # still stale + ), + patch( + "datajunction_server.internal.caching.query_cache_manager.build_access_checker_from_request", + return_value=mock_access_checker, + ), + ): + cache = CachelibCache() + manager = QueryCacheManager(cache, QueryBuildType.MEASURES) + params = QueryRequestParams( + nodes=["foo"], + dimensions=["dim1"], + filters=[], + ) + + # Put stale value in cache to test hit vs miss + key = await manager.build_cache_key(DummyRequest(), params) + cache.set(key, [{"sql": "CACHED"}]) + + background = BackgroundTasks() + + # `no-cache` => should bypass cache + request = DummyRequest(cache_control="no-cache") + result = await manager.get_or_load(background, request, params) + assert result == [{"sql": "SELECT * FROM test"}] + + # Run tasks, should store + for task in background.tasks: + await task() + assert cache.get(key) == [{"sql": "SELECT * FROM test"}] + + # `no-store` => should hit cache, but not store + cache.set(key, [{"sql": "CACHED"}]) + request = DummyRequest(cache_control="no-store") + result = await manager.get_or_load(background, request, params) + assert result == [{"sql": "CACHED"}] # hits stale + + # `no-cache, no-store` => should always fallback but never store + request = DummyRequest(cache_control="no-cache, no-store") + result = await manager.get_or_load(background, request, params) + assert result == [{"sql": "SELECT * FROM test"}] + cache.get(key) == [{"sql": "CACHED"}] # still stale @pytest.mark.asyncio diff --git a/datajunction-server/tests/models/access_test.py b/datajunction-server/tests/models/access_test.py new file mode 100644 index 000000000..1781e1e32 --- /dev/null +++ b/datajunction-server/tests/models/access_test.py @@ -0,0 +1,128 @@ +""" +Tests for ``datajunction_server.models.access``. +""" + +from datajunction_server.models.access import ( + Resource, + ResourceAction, + ResourceRequest, + ResourceType, +) + + +class TestResource: + """Tests for Resource dataclass""" + + def test_resource_hash(self) -> None: + """Test Resource.__hash__ method (line 43)""" + resource1 = Resource(name="test.node", resource_type=ResourceType.NODE) + resource2 = Resource(name="test.node", resource_type=ResourceType.NODE) + resource3 = Resource(name="other.node", resource_type=ResourceType.NODE) + resource4 = Resource(name="test.node", resource_type=ResourceType.NAMESPACE) + + # Same name and type should have same hash + assert hash(resource1) == hash(resource2) + + # Different name should have different hash + assert hash(resource1) != hash(resource3) + + # Different type should have different hash + assert hash(resource1) != hash(resource4) + + # Resources can be used in sets and dicts + resource_set = {resource1, resource2, resource3} + assert len(resource_set) == 2 # resource1 and resource2 are same + + resource_dict = {resource1: "value1"} + assert resource_dict[resource2] == "value1" # Same hash, same key + + +class TestResourceRequest: + """Tests for ResourceRequest dataclass""" + + def test_resource_request_hash(self) -> None: + """Test ResourceRequest.__hash__ method (line 71)""" + resource = Resource(name="test.node", resource_type=ResourceType.NODE) + + request1 = ResourceRequest(verb=ResourceAction.READ, access_object=resource) + request2 = ResourceRequest(verb=ResourceAction.READ, access_object=resource) + request3 = ResourceRequest(verb=ResourceAction.WRITE, access_object=resource) + + other_resource = Resource(name="other.node", resource_type=ResourceType.NODE) + request4 = ResourceRequest( + verb=ResourceAction.READ, + access_object=other_resource, + ) + + # Same verb and resource should have same hash + assert hash(request1) == hash(request2) + + # Different verb should have different hash + assert hash(request1) != hash(request3) + + # Different resource should have different hash + assert hash(request1) != hash(request4) + + # ResourceRequests can be used in sets and dicts + request_set = {request1, request2, request3} + assert len(request_set) == 2 # request1 and request2 are same + + def test_resource_request_eq(self) -> None: + """Test ResourceRequest.__eq__ method (line 74)""" + resource = Resource(name="test.node", resource_type=ResourceType.NODE) + other_resource = Resource(name="other.node", resource_type=ResourceType.NODE) + + request1 = ResourceRequest(verb=ResourceAction.READ, access_object=resource) + request2 = ResourceRequest(verb=ResourceAction.READ, access_object=resource) + request3 = ResourceRequest(verb=ResourceAction.WRITE, access_object=resource) + request4 = ResourceRequest( + verb=ResourceAction.READ, + access_object=other_resource, + ) + + # Same verb and access_object + assert request1 == request2 + + # Different verb + assert request1 != request3 + + # Different access_object + assert request1 != request4 + + def test_resource_request_str(self) -> None: + """Test ResourceRequest.__str__ method (line 77)""" + node_resource = Resource(name="test.node", resource_type=ResourceType.NODE) + namespace_resource = Resource( + name="test.namespace", + resource_type=ResourceType.NAMESPACE, + ) + + read_request = ResourceRequest( + verb=ResourceAction.READ, + access_object=node_resource, + ) + assert str(read_request) == "read:node/test.node" + + write_request = ResourceRequest( + verb=ResourceAction.WRITE, + access_object=namespace_resource, + ) + assert str(write_request) == "write:namespace/test.namespace" + + execute_request = ResourceRequest( + verb=ResourceAction.EXECUTE, + access_object=node_resource, + ) + assert str(execute_request) == "execute:node/test.node" + + delete_request = ResourceRequest( + verb=ResourceAction.DELETE, + access_object=node_resource, + ) + assert str(delete_request) == "delete:node/test.node" + + manage_request = ResourceRequest( + verb=ResourceAction.MANAGE, + access_object=namespace_resource, + ) + assert str(manage_request) == "manage:namespace/test.namespace" From b18e2a9f6ae84a672100b977821d14441d9ce0f9 Mon Sep 17 00:00:00 2001 From: Yian Shang Date: Wed, 31 Dec 2025 20:08:58 -0800 Subject: [PATCH 4/4] Fix --- .../datajunction_server/api/sql.py | 216 +----------------- .../internal/materializations.py | 2 +- 2 files changed, 3 insertions(+), 215 deletions(-) diff --git a/datajunction-server/datajunction_server/api/sql.py b/datajunction-server/datajunction_server/api/sql.py index c16cc251a..3c49a6fb7 100644 --- a/datajunction-server/datajunction_server/api/sql.py +++ b/datajunction-server/datajunction_server/api/sql.py @@ -9,6 +9,7 @@ from fastapi import BackgroundTasks, Depends, Query, Request from sqlalchemy.ext.asyncio import AsyncSession +from datajunction_server.utils import get_current_user from datajunction_server.construction.build_v3 import ( build_metrics_sql, build_measures_sql, @@ -25,6 +26,7 @@ from datajunction_server.internal.caching.cachelib_cache import get_cache from datajunction_server.internal.caching.interface import Cache from datajunction_server.database import Node +from datajunction_server.database.user import User from datajunction_server.database.queryrequest import QueryBuildType from datajunction_server.errors import DJInvalidInputException from datajunction_server.internal.access.authentication.http import SecureAPIRouter @@ -374,220 +376,6 @@ async def get_metrics_sql_v3( ) -@router.get( - "/sql/measures/v3/", - response_model=MeasuresSQLResponse, - name="Get Measures SQL V3", - tags=["sql", "v3"], -) -async def get_measures_sql_v3( - metrics: List[str] = Query([]), - dimensions: List[str] = Query([]), - filters: List[str] = Query([]), - use_materialized: bool = Query(True), - *, - session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), -) -> MeasuresSQLResponse: - """ - Generate pre-aggregated measures SQL for the requested metrics. - - Measures SQL represents the first stage of metric computation - it decomposes - each metric into its atomic aggregation components (e.g., SUM(amount), COUNT(*)) - and produces SQL that computes these components at the requested dimensional grain. - - Metrics are separated into grain groups, which represent sets of metrics that can be - computed together at a common grain. Each grain group produces its own SQL query, which - can be materialized independently to produce intermediate tables that are then queried - to compute final metric values. - - Returns: - One or more `GrainGroupSQL` objects, each containing: - - SQL query computing metric components at the specified grain - - Column metadata with semantic types - - Component details for downstream re-aggregation - - Args: - use_materialized: If True (default), use materialized tables when available. - Set to False when generating SQL for materialization refresh to avoid - circular references. - - See also: `/sql/metrics/v3/` for the final combined query with metric expressions. - """ - result = await build_measures_sql( - session=session, - metrics=metrics, - dimensions=dimensions, - filters=filters, - dialect=Dialect.SPARK, - use_materialized=use_materialized, - ) - - # Build a unified component_aliases map from all grain groups - # This maps component hash names -> actual SQL column aliases - all_component_aliases: dict[str, str] = {} - for gg in result.grain_groups: - all_component_aliases.update(gg.component_aliases) - - # Build metric formulas from decomposed metrics - metric_formulas = [] - for metric_name, decomposed in result.decomposed_metrics.items(): - # Get the combiner expression and rewrite component names to actual SQL aliases - from copy import deepcopy - - combiner_ast = deepcopy(decomposed.derived_ast.select.projection[0]) - - # Replace component hash names with actual SQL aliases in the combiner - for col in combiner_ast.find_all(ast.Column): - col_name = col.name.name if col.name else None - if col_name and col_name in all_component_aliases: - col.name = ast.Name(all_component_aliases[col_name]) - col._table = None - - combiner_str = str(combiner_ast) - - # Determine parent node name from the first grain group that contains this metric - parent_name = None - for gg in result.grain_groups: - if metric_name in gg.metrics: - parent_name = gg.parent_name - break - - # Check if this is a derived metric (references other metrics) - parent_names = result.ctx.parent_map.get(metric_name, []) - is_derived = decomposed.is_derived_for_parents( - parent_names, - result.ctx.nodes, - ) - - # Get component column names as they appear in SQL - # Use the unified component_aliases to resolve hash names -> actual aliases - component_names = [ - all_component_aliases.get(comp.name, comp.name) - for comp in decomposed.components - ] - - metric_formulas.append( - MetricFormulaResponse( - name=metric_name, - short_name=metric_name.split(".")[-1], - combiner=combiner_str, - components=component_names, - is_derived=is_derived, - parent_name=parent_name, - ), - ) - - return MeasuresSQLResponse( - grain_groups=[ - GrainGroupResponse( - sql=gg.sql, - columns=[ - V3ColumnMetadata( - name=col.name, - type=col.type, - semantic_entity=col.semantic_name, - semantic_type=col.semantic_type, - ) - for col in gg.columns - ], - grain=gg.grain, - aggregability=gg.aggregability.value - if hasattr(gg.aggregability, "value") - else str(gg.aggregability), - metrics=gg.metrics, - components=[ - ComponentResponse( - # Use actual SQL alias (metric short name for single-component, hash for multi) - name=gg.component_aliases.get(comp.name, comp.name), - expression=comp.expression, - aggregation=comp.aggregation, - merge=comp.merge, - aggregability=comp.rule.type.value - if hasattr(comp.rule.type, "value") - else str(comp.rule.type), - ) - for comp in gg.components - ], - parent_name=gg.parent_name, - ) - for gg in result.grain_groups - ], - metric_formulas=metric_formulas, - dialect=str(result.dialect) if result.dialect else None, - requested_dimensions=result.requested_dimensions, - ) - - -@router.get( - "/sql/metrics/v3/", - response_model=V3TranslatedSQL, - name="Get Metrics SQL V3", - tags=["sql", "v3"], -) -async def get_metrics_sql_v3( - metrics: List[str] = Query([]), - dimensions: List[str] = Query([]), - filters: List[str] = Query([]), - *, - session: AsyncSession = Depends(get_session), - current_user: User = Depends(get_current_user), -) -> V3TranslatedSQL: - """ - Generate final metrics SQL with fully computed metric expressions. - - Metrics SQL is the second (and final) stage of metric computation - it takes - the pre-aggregated components from Measures SQL and applies combiner expressions - to produce the actual metric values requested. - - - Metric components are re-aggregated as needed to match the requested - dimensional grain. - - - Derived metrics (defined as expressions over other metrics) - (e.g., `conversion_rate = order_count / visitor_count`) are computed by - substituting component references with their re-aggregated expressions. - - - When metrics come from different fact tables, their - grain groups are FULL OUTER JOINed on the common dimensions, with COALESCE - for dimension columns to handle NULLs from non-matching rows. - - - Dimension references in metric expressions are resolved to their - final column aliases. - - Returns: - A single SQL query that: - - Defines CTEs for each grain group (pre-aggregated component data) or - uses materialized pre-agg tables when available - - Joins grain groups on shared dimensions (if multiple) - - Builds dimensions with coalesce and metrics with combiner expressions - - Groups by dimensions to finalize re-aggregation - - See also: `/sql/measures/v3/` for the underlying pre-aggregated components. - """ - - result = await build_metrics_sql( - session=session, - metrics=metrics, - dimensions=dimensions, - filters=filters, - dialect=Dialect.SPARK, - ) - - return V3TranslatedSQL( - sql=result.sql, - columns=[ - V3ColumnMetadata( - name=col.name, - type=col.type, - semantic_entity=col.semantic_name, - semantic_type=col.semantic_type, - ) - for col in result.columns - ], - dialect=result.dialect, - ) - - @router.get("/sql/", response_model=TranslatedSQL, name="Get SQL For Metrics") async def get_sql_for_metrics( metrics: List[str] = Query([]), diff --git a/datajunction-server/datajunction_server/internal/materializations.py b/datajunction-server/datajunction_server/internal/materializations.py index 48d25b247..e0dd124df 100644 --- a/datajunction-server/datajunction_server/internal/materializations.py +++ b/datajunction-server/datajunction_server/internal/materializations.py @@ -192,7 +192,7 @@ async def build_cube_materialization_config( f"node `{current_revision.name}` and job " f"`{upsert_input.job.name}` as" # type: ignore " the config does not have valid configuration for " - f"engine `{upsert_input.job.name}`." + f"engine `{upsert_input.job.name}`." # type: ignore ), ) from exc