diff --git a/poetry.lock b/poetry.lock index 18f73ff..942a151 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1574,6 +1574,28 @@ importlib-metadata = {version = ">=4.6.0", markers = "python_version < \"3.10\"" colors = ["colorama"] plugins = ["setuptools"] +[[package]] +name = "libadvian" +version = "1.10.0" +description = "Small helpers that do not warrant their own library" +optional = false +python-versions = ">=3.9,<4.0" +groups = ["main"] +files = [] +develop = false + +[package.extras] +all = ["http"] +http = ["frozendict", "requests"] +logstash = ["http"] +vector = ["http"] + +[package.source] +type = "git" +url = "https://gitlab.com/advian-oss/python-libadvian.git" +reference = "log_levels" +resolved_reference = "1002a851ba6284551132dbef075c2fa0e1ba110d" + [[package]] name = "librt" version = "0.7.2" @@ -3102,4 +3124,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">3.9.1,<4.0" -content-hash = "28fb0374c3395725d144c439017e93b9dbf4aefaeb1e7f1a113af44b45d763cc" +content-hash = "607be85a288923aad6aa1def3b28dee07c7187cecacf28fa0e653e604f00eb24" diff --git a/pyproject.toml b/pyproject.toml index 6f5a8bb..4955277 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 @@ repository = "https://github.com/pvarki/python-libpvarki/" license = "MIT" readme = "README.rst" - [tool.black] line-length = 120 target-version = ['py38'] @@ -65,6 +64,7 @@ aiohttp = ">=3.10.2,<4.0" aiodns = "^3.0" brotli = "^1.0" cchardet = { version="^2.1", python="<=3.10"} +libadvian = { git = "https://gitlab.com/advian-oss/python-libadvian.git", branch = "log_levels" } [tool.poetry.group.dev.dependencies] diff --git a/src/libpvarki/auditlogging/README.md b/src/libpvarki/auditlogging/README.md new file mode 100644 index 0000000..eddf198 --- /dev/null +++ b/src/libpvarki/auditlogging/README.md @@ -0,0 +1,228 @@ +# libpvarki.auditlogging + +`libpvarki` module for providing structured audit logging compliant with organizational requirements. + +## Structure + +``` +auditlogging/ +├── README.md +├── src/ +│ └── libpvarki/ +│ └── auditlogging/ +│ ├── __init__.py # Public API, AUDIT level setup +│ ├── context.py # ContextVars for async-safe request context +│ ├── middleware.py # FastAPI AuditMiddleware +│ ├── helpers.py # audit_log() and convenience functions +│ ├── propagation.py # Service-to-service header propagation +│ └── py.typed # PEP 561 marker +└── tests/ + └── test_auditlogging.py +``` + +## Installation + +Copy the `auditlogging/` directory into `libpvarki/src/libpvarki/`: + +```bash +cp -r src/libpvarki/auditlogging /path/to/python-libpvarki/src/libpvarki/ +cp tests/test_auditlogging.py /path/to/python-libpvarki/tests/ +``` + +### Prerequisites + +Requires libadvian with MR #15 for native AUDIT level. Until merged: + +```toml +# pyproject.toml +[tool.poetry.dependencies] +libadvian = { git = "https://gitlab.com/advian-oss/python-libadvian.git", branch = "log_levels" } +``` + +The module includes a fallback that adds AUDIT level if libadvian doesn't have it yet. + +## Integration with Existing Stack + +``` +libadvian.logging ← MR #15 adds AUDIT level + ↓ +libpvarki.logging ← ECS formatting via ecs-logging + ↓ +libpvarki.auditlogging ← THIS MODULE + ↓ +rmapi / takrmapi / ocsprest / products +``` + +## Quick Start + +### 1. Initialize in FastAPI app + +```python +from fastapi import FastAPI +from libpvarki.auditlogging import init_audit, AuditMiddleware +import logging + +app = FastAPI() +app.add_middleware(AuditMiddleware) + +@app.on_event("startup") +async def startup(): + init_audit(logging.INFO) +``` + +### 2. Log audit events + +```python +import logging +from libpvarki.auditlogging import audit_log + +LOGGER = logging.getLogger(__name__) + +LOGGER.audit( + "Certificate issued for user", + extra=audit_log( + category="iam", + action="cert_issue", + outcome="success", + target_user="NORPPA11", + target_resource="DEADBEEF", # cert serial + ) +) +``` + +### 3. Propagate context to downstream services + +```python +from libpvarki.mtlshelp.session import get_session +from libpvarki.auditlogging import get_propagation_headers + +session = await get_session(client_cert, client_key, ca_cert) +headers = get_propagation_headers() # Includes X-Initiator-* headers +await session.post(url, json=data, headers=headers) +``` + +## Request Flow + +``` +User (NORPPA11) ─mTLS─► nginx ───► rmapi ───► takrmapi + │ │ │ + │ │ └── Sees X-Initiator-User: NORPPA11 + │ └── Extracts from X-ClientCert-DN + └── Sets X-ClientCert-DN: CN=NORPPA11 +``` + +## Header Conventions + +### nginx → service (direct mTLS) + +``` +X-Request-ID: +X-Real-IP: +X-ClientCert-DN: CN=,O=PVARKI,C=FI +X-ClientCert-Serial: +``` + +### service → service (propagation) + +``` +X-Request-ID: +X-Initiator-User: +X-Initiator-IP: +X-Initiator-Role: +X-Initiator-Cert-Serial: +``` + +## nginx Configuration + +```nginx +# In your nginx server block +proxy_set_header X-Request-ID $request_id; +proxy_set_header X-Real-IP $remote_addr; +proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; +proxy_set_header X-ClientCert-DN $ssl_client_s_dn; +proxy_set_header X-ClientCert-Serial $ssl_client_serial; +``` + +## Event Categories + +| Category | Use For | +|----------|---------| +| `authentication` | Login, logout, OTP exchange, JWT validation | +| `authorization` | Permission checks, access denied | +| `iam` | Enrollment, cert issuance, revocation | +| `configuration` | Settings changes, admin actions | +| `session` | JWT creation, refresh, expiry | +| `intrusion_detection` | Failed attempts, anomalies | + +## Convenience Functions + +```python +from libpvarki.auditlogging import ( + audit_authentication, # category="authentication" + audit_iam, # category="iam" + audit_authorization, # category="authorization" + audit_configuration, # category="configuration" + audit_session, # category="session" + audit_anomaly, # category="intrusion_detection", outcome="failure" +) + +# Examples +LOGGER.audit("Login successful", extra=audit_authentication("login", outcome="success")) +LOGGER.audit("Cert issued", extra=audit_iam("cert_issue", target_user="NORPPA11")) +LOGGER.audit("Brute force detected", extra=audit_anomaly("brute_force", error_message="5 failed attempts")) +``` + +## ECS Output Example + +With `LOG_CONSOLE_FORMATTER=ecs` (default), output is ECS-compliant JSON: + +```json +{ + "@timestamp": "2025-12-21T00:00:00.000Z", + "ecs.version": "1.6.0", + "log.level": "AUDIT", + "log.logger": "rasenmaeher_api.routes.token", + "message": "OTP exchange successful for NORPPA11", + "event.category": "authentication", + "event.action": "otp_exchange", + "event.outcome": "success", + "source.ip": "203.0.113.50", + "source.user.name": "NORPPA11", + "tls.client.x509.serial_number": "DEADBEEF", + "user.target.name": "NORPPA11", + "trace.id": "abc-123-def-456", + "service.name": "rmapi", + "service.version": "1.6.4" +} +``` + +## Environment Variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `LOG_CONSOLE_FORMATTER` | `ecs` | `ecs` for JSON, `local` for human-readable | +| `SERVICE_NAME` | hostname | Service name in logs | +| `RELEASE_TAG` | `unknown` | Service version in logs | + +## Testing + +```bash +cd python-libpvarki +pytest tests/test_auditlogging.py -v +``` + +## Migration from init_logging + +Replace `init_logging` with `init_audit` to enable AUDIT level: + +```python +# Before +from libpvarki.logging import init_logging +init_logging(logging.INFO) + +# After +from libpvarki.auditlogging import init_audit +init_audit(logging.INFO) +``` + +Or continue using `init_logging` - the AUDIT level is registered on module import. diff --git a/src/libpvarki/auditlogging/__init__.py b/src/libpvarki/auditlogging/__init__.py new file mode 100644 index 0000000..41b5f30 --- /dev/null +++ b/src/libpvarki/auditlogging/__init__.py @@ -0,0 +1,227 @@ +""" +PVARKI Audit Logging Module. + +Add-on to libpvarki.logging that provides structured audit logging with: + +- **AUDIT log level (25)** - Between INFO and WARNING +- **Request context propagation** - via ContextVars (async-safe) +- **Service-to-service propagation** - via HTTP headers +- **ECS-compliant fields** - works with existing ecs-logging formatter + +This module builds on: + +- libadvian.logging (base logging, MR #15 adds AUDIT level) +- libpvarki.logging (ECS formatting via ecs-logging) + +Quick Start +----------- +1. Initialize logging in your FastAPI app:: + + from fastapi import FastAPI + from libpvarki.auditlogging import init_audit, AuditMiddleware + import logging + + app = FastAPI() + app.add_middleware(AuditMiddleware) + + @app.on_event("startup") + async def startup(): + init_audit(logging.INFO) + +2. Log audit events in your code:: + + import logging + from libpvarki.auditlogging import audit_log + + LOGGER = logging.getLogger(__name__) + + LOGGER.audit( + "Certificate issued for user", + extra=audit_log( + category="iam", + action="cert_issue", + outcome="success", + target_user="NORPPA11", + target_resource="DEADBEEF", + ) + ) + +3. Propagate context to downstream services:: + + from libpvarki.auditlogging import get_propagation_headers + from libpvarki.mtlshelp.session import get_session + + session = await get_session(...) + headers = get_propagation_headers() + await session.post(url, json=data, headers=headers) + + +Environment Variables +--------------------- +LOG_CONSOLE_FORMATTER : str + "ecs" (default) for JSON, "local" for human-readable. + (Inherited from libpvarki.logging) +SERVICE_NAME : str + Service identifier for logs (defaults to HOSTNAME). +RELEASE_TAG : str + Service version for logs. + +Header Conventions +------------------ +nginx -> service:: + + X-Request-ID: Trace correlation ID + X-Real-IP: Client IP + X-ClientCert-DN: mTLS certificate DN + X-ClientCert-Serial: mTLS certificate serial + +service -> service:: + + X-Request-ID: Trace correlation ID + X-Initiator-User: Original user/callsign + X-Initiator-IP: Original client IP + X-Initiator-Role: User role + X-Initiator-Cert-Serial: Original cert serial + X-Initiator-Session: Session ID +""" + +import logging +from typing import Any + +# Import existing libpvarki logging (which builds on libadvian) +from libpvarki.logging import init_logging + +# Context management +from .context import ( + AuditContext, + get_audit_context, + set_audit_context, + clear_audit_context, +) + +# FastAPI middleware +from .middleware import ( + AuditMiddleware, + update_audit_user, +) + +# Logging helpers +from .helpers import ( + audit_log, + audit_extra, + audit_authentication, + audit_iam, + audit_authorization, + audit_configuration, + audit_session, + audit_anomaly, +) + +# Service-to-service propagation +from .propagation import ( + get_propagation_headers, + inject_audit_context, + AuditContextClientMixin, + create_audit_trace_config, +) + + +# ============================================================================= +# AUDIT Level Setup +# ============================================================================= + +# AUDIT log level: between INFO (20) and WARNING (30) +AUDIT = 25 + + +def _ensure_audit_level() -> None: + """ + Ensure AUDIT log level is registered with Python logging. + + This provides a fallback for libadvian versions before MR #15 is merged. + Once MR #15 is merged and released, libadvian will handle this natively. + + Safe to call multiple times - will not duplicate registration. + + The AUDIT level (25) sits between INFO (20) and WARNING (30), + making it visible at INFO level but filterable separately. + """ + if hasattr(logging, "AUDIT"): + return # Already registered by libadvian or previous call + + logging.addLevelName(AUDIT, "AUDIT") + setattr(logging, "AUDIT", AUDIT) + + def audit(self: logging.Logger, message: str, *args: Any, **kwargs: Any) -> None: + """ + Log an audit message at level 25. + + Usage:: + + LOGGER.audit("Event description", extra=audit_log(...)) + """ + if self.isEnabledFor(AUDIT): + self._log(AUDIT, message, args, **kwargs) + + logging.Logger.audit = audit # type: ignore[attr-defined] + + +def init_audit(level: int = logging.INFO) -> None: + """ + Initialize logging with AUDIT level support. + + Call this instead of ``init_logging()`` in services that need audit logging. + Sets up the AUDIT level and configures formatters via libpvarki.logging. + + Args: + level: Minimum log level. Default INFO (20) shows AUDIT (25) events. + Use logging.DEBUG (10) for verbose output. + Use logging.WARNING (30) to suppress AUDIT events. + + Example:: + + from libpvarki.auditlogging import init_audit + import logging + + # In your app startup: + init_audit(logging.INFO) + + # Or for debugging: + init_audit(logging.DEBUG) + """ + _ensure_audit_level() + init_logging(level) + + +# Ensure AUDIT level exists on module import +# This allows LOGGER.audit() to work even before init_audit() is called +_ensure_audit_level() + + +__all__ = [ + # Initialization + "init_audit", + "AUDIT", + # Context management + "AuditContext", + "get_audit_context", + "set_audit_context", + "clear_audit_context", + # Middleware + "AuditMiddleware", + "update_audit_user", + # Logging helpers + "audit_log", + "audit_extra", + "audit_authentication", + "audit_iam", + "audit_authorization", + "audit_configuration", + "audit_session", + "audit_anomaly", + # Propagation + "get_propagation_headers", + "inject_audit_context", + "AuditContextClientMixin", + "create_audit_trace_config", +] diff --git a/src/libpvarki/auditlogging/context.py b/src/libpvarki/auditlogging/context.py new file mode 100644 index 0000000..dd895e5 --- /dev/null +++ b/src/libpvarki/auditlogging/context.py @@ -0,0 +1,160 @@ +""" +Request context management using ContextVars. + +ContextVars provide task-local storage that works correctly with asyncio. +Each concurrent request gets its own isolated context automatically. + +Usage: + The AuditMiddleware sets context at request start. + The audit_log() helper reads it when logging. + Context is cleared at request end to prevent leakage. +""" + +from contextvars import ContextVar +from dataclasses import dataclass, field +from typing import Optional, Dict, Any +import uuid + + +@dataclass +class AuditContext: + """ + Container for request-scoped audit context. + + Holds initiator information extracted from incoming requests, + either from nginx headers (direct mTLS) or propagation headers + (service-to-service calls). + + Attributes: + trace_id: Correlation ID for the entire request chain. + initiator_ip: Source IP address of the original requester. + initiator_user: Username/callsign of the initiator. + initiator_role: Role of the initiator (admin, user, etc.). + initiator_cert_serial: mTLS certificate serial number. + initiator_cert_cn: mTLS certificate Common Name. + initiator_session: Session ID if applicable. + is_propagated: True if context came from upstream service headers. + """ + + # Correlation ID for request chain tracing + trace_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + # Initiator information (who caused this action) + initiator_ip: str = "" + initiator_user: str = "" + initiator_role: str = "" + initiator_cert_serial: str = "" + initiator_cert_cn: str = "" + initiator_session: str = "" + + # Metadata + is_propagated: bool = False + + def to_ecs_fields(self) -> Dict[str, Any]: + """ + Convert context to ECS-compliant field dictionary. + + Returns: + Dict with ECS field names. Empty values are excluded. + """ + result: Dict[str, Any] = { + "trace.id": self.trace_id, + } + + if self.initiator_ip: + result["source.ip"] = self.initiator_ip + if self.initiator_user: + result["source.user.name"] = self.initiator_user + if self.initiator_role: + result["source.user.roles"] = [self.initiator_role] + if self.initiator_cert_serial: + result["tls.client.x509.serial_number"] = self.initiator_cert_serial + if self.initiator_cert_cn: + result["tls.client.x509.subject.common_name"] = self.initiator_cert_cn + if self.initiator_session: + result["session.id"] = self.initiator_session + + return result + + +# Module-level ContextVar instance +# Each async task automatically gets isolated storage +_audit_context: ContextVar[AuditContext] = ContextVar("audit_context", default=AuditContext()) + + +def get_audit_context() -> AuditContext: + """ + Get the current request's audit context. + + Safe to call from anywhere. Returns empty context if called + outside of a request scope (e.g., during startup). + + Returns: + Current AuditContext for this async task. + """ + return _audit_context.get() + + +def set_audit_context( + trace_id: Optional[str] = None, + initiator_ip: Optional[str] = None, + initiator_user: Optional[str] = None, + initiator_role: Optional[str] = None, + initiator_cert_serial: Optional[str] = None, + initiator_cert_cn: Optional[str] = None, + initiator_session: Optional[str] = None, + is_propagated: Optional[bool] = None, +) -> AuditContext: + """ + Set or update the audit context for the current request. + + Only provided (non-None) fields are updated. Other fields retain + their current values. This allows incremental updates, e.g., + setting user info after JWT validation. + + Typically called by: + - AuditMiddleware at request start + - Auth dependencies after JWT validation + - Service code when additional context is available + + Args: + trace_id: Correlation ID (from X-Request-ID or generated). + initiator_ip: Source IP address. + initiator_user: Username/callsign. + initiator_role: User role (admin, user, service, etc.). + initiator_cert_serial: mTLS certificate serial number. + initiator_cert_cn: mTLS certificate Common Name. + initiator_session: Session identifier. + is_propagated: True if context came from upstream service. + + Returns: + The updated AuditContext. + """ + current = _audit_context.get() + + new_context = AuditContext( + trace_id=trace_id if trace_id is not None else current.trace_id, + initiator_ip=initiator_ip if initiator_ip is not None else current.initiator_ip, + initiator_user=initiator_user if initiator_user is not None else current.initiator_user, + initiator_role=initiator_role if initiator_role is not None else current.initiator_role, + initiator_cert_serial=( + initiator_cert_serial if initiator_cert_serial is not None else current.initiator_cert_serial + ), + initiator_cert_cn=initiator_cert_cn if initiator_cert_cn is not None else current.initiator_cert_cn, + initiator_session=initiator_session if initiator_session is not None else current.initiator_session, + is_propagated=is_propagated if is_propagated is not None else current.is_propagated, + ) + + _audit_context.set(new_context) + return new_context + + +def clear_audit_context() -> None: + """ + Reset context to empty defaults. + + Must be called at request end to prevent context leakage between + requests. The AuditMiddleware handles this automatically in its + finally block. + """ + _audit_context.set(AuditContext()) diff --git a/src/libpvarki/auditlogging/helpers.py b/src/libpvarki/auditlogging/helpers.py new file mode 100644 index 0000000..a481ff3 --- /dev/null +++ b/src/libpvarki/auditlogging/helpers.py @@ -0,0 +1,299 @@ +""" +Convenience functions for audit logging. + +These helpers format the 'extra' dict for LOGGER.audit() calls with +proper ECS field mapping and automatic context injection from ContextVars. + +The extra fields will be properly formatted by ecs-logging inlibpvarki.logging into ECS-compliant JSON output. + +Usage:: + + import logging + from libpvarki.auditlogging import audit_log + + LOGGER = logging.getLogger(__name__) + + LOGGER.audit( + "Certificate issued for user", + extra=audit_log( + category="iam", + action="cert_issue", + outcome="success", + target_user="NORPPA11", + target_resource="DEADBEEF", + target_resource_type="certificate", + ) + ) +""" + +import os +from typing import Optional, Dict, Any + +from .context import get_audit_context + + +# Service identification from environment +# These match what's typically set in PVARKI docker-compose files +SERVICE_NAME = os.getenv("SERVICE_NAME", os.getenv("HOSTNAME", "pvarki")) +SERVICE_VERSION = os.getenv("RELEASE_TAG", os.getenv("SERVICE_VERSION", "unknown")) + + +def audit_log( + category: str, + action: str, + outcome: str = "success", + # Initiator overrides (normally from context) + initiator_user: Optional[str] = None, + initiator_role: Optional[str] = None, + initiator_ip: Optional[str] = None, + initiator_cert_serial: Optional[str] = None, + # Target fields + target_user: Optional[str] = None, + target_resource: Optional[str] = None, + target_resource_type: Optional[str] = None, + # Error information + error_message: Optional[str] = None, + error_code: Optional[str] = None, + # Additional fields + **extra_fields: Any, +) -> Dict[str, Any]: + """ + Build an ECS-compliant extra dict for audit logging. + + Automatically injects initiator context from AuditMiddleware. + Use with ``LOGGER.audit("message", extra=audit_log(...))``. + + Args: + category: Event category per ECS. Common values: + + - ``authentication``: Login, logout, token exchange + - ``authorization``: Permission checks + - ``iam``: Identity management, cert issuance + - ``configuration``: Settings changes + - ``session``: Session lifecycle + - ``network``: Connection events + - ``intrusion_detection``: Security anomalies + + action: Specific action identifier. Examples: + + - ``otp_exchange``, ``jwt_validate``, ``mtls_auth`` + - ``cert_issue``, ``cert_revoke``, ``user_enroll`` + - ``config_update``, ``permission_grant`` + + outcome: Result of the action: + + - ``success``: Action completed successfully + - ``failure``: Action failed + - ``unknown``: Outcome not determined + + initiator_user: Override context initiator user. + initiator_role: Override context initiator role. + initiator_ip: Override context initiator IP. + initiator_cert_serial: Override context cert serial. + target_user: User affected by the action. + target_resource: Resource identifier (cert serial, endpoint, etc.). + target_resource_type: Type of resource (certificate, user, endpoint). + error_message: Human-readable error description for failures. + error_code: Machine-readable error code for failures. + **extra_fields: Additional fields added under ``pvarki.*`` namespace. + + Returns: + Dict suitable for logging extra parameter. + + Example:: + + LOGGER.audit( + "OTP exchange successful", + extra=audit_log( + category="authentication", + action="otp_exchange", + outcome="success", + target_user="NORPPA11", + ) + ) + """ + ctx = get_audit_context() + + # Build ECS-compliant extra dict + result: Dict[str, Any] = { + # Event classification (ECS) + "event.category": category, + "event.action": action, + "event.outcome": outcome, + # Service identification + "service.name": SERVICE_NAME, + "service.version": SERVICE_VERSION, + # Correlation + "trace.id": ctx.trace_id, + } + + # Initiator fields (explicit params override context) + _initiator_user = initiator_user or ctx.initiator_user + _initiator_role = initiator_role or ctx.initiator_role + _initiator_ip = initiator_ip or ctx.initiator_ip + _initiator_cert_serial = initiator_cert_serial or ctx.initiator_cert_serial + + if _initiator_ip: + result["source.ip"] = _initiator_ip + if _initiator_user: + result["source.user.name"] = _initiator_user + if _initiator_role: + result["source.user.roles"] = [_initiator_role] + if _initiator_cert_serial: + result["tls.client.x509.serial_number"] = _initiator_cert_serial + if ctx.initiator_cert_cn: + result["tls.client.x509.subject.common_name"] = ctx.initiator_cert_cn + if ctx.initiator_session: + result["session.id"] = ctx.initiator_session + + # Target fields (ECS user.target.* for affected user) + if target_user: + result["user.target.name"] = target_user + if target_resource: + result["pvarki.target.resource"] = target_resource + if target_resource_type: + result["pvarki.target.resource_type"] = target_resource_type + + # Error information (ECS error.*) + if error_message: + result["error.message"] = error_message + if error_code: + result["error.code"] = error_code + + # Additional fields under pvarki.* namespace + for key, value in extra_fields.items(): + if value is not None: + result[f"pvarki.{key}"] = value + + return result + + +def audit_extra(**fields: Any) -> Dict[str, Any]: + """ + Simple wrapper to add trace context to any log call. + + For non-audit logs that still need trace correlation. + Less structured than audit_log(), just adds trace.id and extra fields. + + Args: + **fields: Fields to include in the extra dict. + + Returns: + Dict with trace.id and provided fields. + + Example:: + + LOGGER.info("Processing request", extra=audit_extra( + endpoint="/api/v1/users", + method="POST", + )) + """ + ctx = get_audit_context() + result: Dict[str, Any] = { + "trace.id": ctx.trace_id, + "service.name": SERVICE_NAME, + } + result.update(fields) + return result + + +# ============================================================================= +# Convenience wrappers for common event categories +# ============================================================================= + + +def audit_authentication( + action: str, + outcome: str = "success", + target_user: Optional[str] = None, + **kwargs: Any, +) -> Dict[str, Any]: + """Build audit log extra for authentication events.""" + return audit_log( + category="authentication", + action=action, + outcome=outcome, + target_user=target_user, + **kwargs, + ) + + +def audit_iam( + action: str, + outcome: str = "success", + target_user: Optional[str] = None, + target_resource: Optional[str] = None, + **kwargs: Any, +) -> Dict[str, Any]: + """Build audit log extra for identity/access management events.""" + return audit_log( + category="iam", + action=action, + outcome=outcome, + target_user=target_user, + target_resource=target_resource, + **kwargs, + ) + + +def audit_authorization( + action: str, + outcome: str = "success", + target_user: Optional[str] = None, + **kwargs: Any, +) -> Dict[str, Any]: + """Build audit log extra for authorization events.""" + return audit_log( + category="authorization", + action=action, + outcome=outcome, + target_user=target_user, + **kwargs, + ) + + +def audit_configuration( + action: str, + outcome: str = "success", + target_resource: Optional[str] = None, + **kwargs: Any, +) -> Dict[str, Any]: + """Build audit log extra for configuration change events.""" + return audit_log( + category="configuration", + action=action, + outcome=outcome, + target_resource=target_resource, + **kwargs, + ) + + +def audit_session( + action: str, + outcome: str = "success", + target_user: Optional[str] = None, + **kwargs: Any, +) -> Dict[str, Any]: + """Build audit log extra for session lifecycle events.""" + return audit_log( + category="session", + action=action, + outcome=outcome, + target_user=target_user, + **kwargs, + ) + + +def audit_anomaly( + action: str, + error_message: Optional[str] = None, + **kwargs: Any, +) -> Dict[str, Any]: + """Build audit log extra for security anomalies (always failure).""" + return audit_log( + category="intrusion_detection", + action=action, + outcome="failure", + error_message=error_message, + **kwargs, + ) diff --git a/src/libpvarki/auditlogging/middleware.py b/src/libpvarki/auditlogging/middleware.py new file mode 100644 index 0000000..aba25b4 --- /dev/null +++ b/src/libpvarki/auditlogging/middleware.py @@ -0,0 +1,213 @@ +""" +FastAPI/Starlette middleware for automatic audit context setup. + +Extracts initiator information from incoming requests: +1. nginx headers (X-ClientCert-*, X-Real-IP) for direct mTLS requests +2. Propagation headers (X-Initiator-*) for service-to-service calls + +Sets ContextVars that are automatically read by audit_log() helper. +""" + +import logging +import uuid +from typing import Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +from .context import set_audit_context, clear_audit_context + +LOGGER = logging.getLogger(__name__) + + +# Header names for nginx mTLS info +HEADER_REQUEST_ID = "X-Request-ID" +HEADER_REAL_IP = "X-Real-IP" +HEADER_FORWARDED_FOR = "X-Forwarded-For" +HEADER_CLIENT_CERT_DN = "X-ClientCert-DN" +HEADER_CLIENT_CERT_SERIAL = "X-ClientCert-Serial" + +# Header names for service-to-service propagation +HEADER_INITIATOR_USER = "X-Initiator-User" +HEADER_INITIATOR_IP = "X-Initiator-IP" +HEADER_INITIATOR_ROLE = "X-Initiator-Role" +HEADER_INITIATOR_CERT_SERIAL = "X-Initiator-Cert-Serial" +HEADER_INITIATOR_SESSION = "X-Initiator-Session" + + +def _parse_cn_from_dn(dn: str) -> str: + """ + Extract Common Name from Distinguished Name string. + + Args: + dn: Distinguished Name, e.g., "CN=NORPPA11,O=PVARKI,C=FI" + + Returns: + The CN value, or empty string if not found. + """ + if not dn: + return "" + + for part in dn.split(","): + part = part.strip() + if part.upper().startswith("CN="): + return part[3:] + return "" + + +class AuditMiddleware(BaseHTTPMiddleware): + """ + Middleware to extract and set audit context for each request. + + Handles two scenarios: + + 1. Direct requests via nginx with mTLS: + - Reads X-ClientCert-DN, X-ClientCert-Serial from nginx + - Reads X-Real-IP or X-Forwarded-For for source IP + + 2. Service-to-service calls with propagated context: + - Reads X-Initiator-* headers set by upstream service + - Preserves original initiator identity through the chain + + Priority: Direct mTLS headers take precedence over propagated headers, + as they represent verified identity from nginx. + + nginx configuration example:: + + proxy_set_header X-Request-ID $request_id; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-ClientCert-DN $ssl_client_s_dn; + proxy_set_header X-ClientCert-Serial $ssl_client_serial; + """ + + async def dispatch(self, request: Request, call_next: Callable[[Request], Response]) -> Response: + """Extract context from headers and process request.""" + + # === Trace ID (correlation) === + trace_id = request.headers.get(HEADER_REQUEST_ID, "") + if not trace_id: + trace_id = str(uuid.uuid4()) + + # === Source IP === + source_ip = self._extract_source_ip(request) + + # === Initiator Identity === + # Try direct mTLS first (nginx headers) + cert_dn = request.headers.get(HEADER_CLIENT_CERT_DN, "") + cert_serial = request.headers.get(HEADER_CLIENT_CERT_SERIAL, "") + cert_cn = _parse_cn_from_dn(cert_dn) + + # Check for propagated context (service-to-service) + prop_user = request.headers.get(HEADER_INITIATOR_USER, "") + prop_ip = request.headers.get(HEADER_INITIATOR_IP, "") + prop_role = request.headers.get(HEADER_INITIATOR_ROLE, "") + prop_cert_serial = request.headers.get(HEADER_INITIATOR_CERT_SERIAL, "") + prop_session = request.headers.get(HEADER_INITIATOR_SESSION, "") + + # Determine final values (direct mTLS takes precedence) + is_propagated = False + if cert_cn: + # Direct mTLS request - use cert info + initiator_user = cert_cn + initiator_cert_serial = cert_serial + initiator_ip = source_ip + initiator_role = "" + initiator_session = "" + elif prop_user: + # Service-to-service with propagated context + is_propagated = True + initiator_user = prop_user + initiator_ip = prop_ip or source_ip + initiator_role = prop_role + initiator_cert_serial = prop_cert_serial + initiator_session = prop_session + else: + # No identity info - just IP + initiator_user = "" + initiator_ip = source_ip + initiator_role = "" + initiator_cert_serial = "" + initiator_session = "" + + # Set context for this request + set_audit_context( + trace_id=trace_id, + initiator_ip=initiator_ip, + initiator_user=initiator_user, + initiator_role=initiator_role, + initiator_cert_serial=initiator_cert_serial, + initiator_cert_cn=cert_cn, + initiator_session=initiator_session, + is_propagated=is_propagated, + ) + + try: + response = await call_next(request) + # Add trace ID to response for debugging/correlation + response.headers[HEADER_REQUEST_ID] = trace_id + return response + finally: + # Always clear context to prevent leakage + clear_audit_context() + + def _extract_source_ip(self, request: Request) -> str: + """ + Extract client IP from request headers. + + Priority: + 1. X-Real-IP (set by nginx) + 2. X-Forwarded-For (first IP in chain) + 3. Direct client IP from connection + + Args: + request: The incoming Starlette request. + + Returns: + Client IP address, or empty string if not available. + """ + # Prefer X-Real-IP (typically set by nginx to actual client) + real_ip = request.headers.get(HEADER_REAL_IP, "") + if real_ip: + return real_ip.strip() + + # Fall back to X-Forwarded-For (take first/leftmost = original client) + forwarded_for = request.headers.get(HEADER_FORWARDED_FOR, "") + if forwarded_for: + return forwarded_for.split(",")[0].strip() + + # Last resort: direct connection IP + if request.client: + return request.client.host + + return "" + + +def update_audit_user(user: str, role: str = "", session: str = "") -> None: + """ + Update audit context with user info after authentication. + + Call this after JWT validation or other auth mechanism in your + FastAPI dependency to enrich the audit context with user identity. + + Args: + user: Username or callsign. + role: User role (admin, user, operator, etc.). + session: Session identifier if applicable. + + Example:: + + async def get_current_user(token: str = Depends(oauth2_scheme)): + payload = decode_jwt(token) + update_audit_user( + user=payload["sub"], + role=payload.get("role", ""), + ) + return payload + """ + set_audit_context( + initiator_user=user, + initiator_role=role, + initiator_session=session, + ) diff --git a/src/libpvarki/auditlogging/propagation.py b/src/libpvarki/auditlogging/propagation.py new file mode 100644 index 0000000..9c4809d --- /dev/null +++ b/src/libpvarki/auditlogging/propagation.py @@ -0,0 +1,190 @@ +""" +Service-to-service audit context propagation. + +When one PVARKI service calls another (e.g., rmapi -> takrmapi), +the original initiator information must be passed along so audit +logs in downstream services correctly attribute actions. + +This module provides helpers to: +1. Get headers to include in outgoing HTTP requests +2. Inject context into aiohttp client sessions + +Example usage with aiohttp (already a libpvarki dependency):: + + from libpvarki.auditlogging import get_propagation_headers + import aiohttp + + async def call_product_api(url: str, data: dict): + headers = get_propagation_headers() + async with aiohttp.ClientSession() as session: + await session.post(url, json=data, headers=headers) +""" + +from typing import Dict, Optional, Any + +from .context import get_audit_context + +# Header names for propagation (must match middleware.py) +HEADER_REQUEST_ID = "X-Request-ID" +HEADER_INITIATOR_USER = "X-Initiator-User" +HEADER_INITIATOR_IP = "X-Initiator-IP" +HEADER_INITIATOR_ROLE = "X-Initiator-Role" +HEADER_INITIATOR_CERT_SERIAL = "X-Initiator-Cert-Serial" +HEADER_INITIATOR_SESSION = "X-Initiator-Session" + + +def get_propagation_headers() -> Dict[str, str]: + """ + Get HTTP headers to propagate audit context to downstream services. + + Include these headers when making HTTP requests to other PVARKI + services to preserve the initiator chain for audit logging. + + Returns: + Dict of header name -> value. Only non-empty values included. + + Example with aiohttp:: + + from libpvarki.auditlogging import get_propagation_headers + import aiohttp + + async with aiohttp.ClientSession() as session: + headers = get_propagation_headers() + await session.post(url, json=data, headers=headers) + + Example with libpvarki.mtlshelp.session:: + + from libpvarki.mtlshelp.session import get_session + from libpvarki.auditlogging import get_propagation_headers + + session = await get_session(client_cert, client_key, ca_cert) + headers = get_propagation_headers() + async with session.post(url, json=data, headers=headers) as resp: + ... + """ + ctx = get_audit_context() + headers: Dict[str, str] = {} + + # Always include trace ID for correlation + if ctx.trace_id: + headers[HEADER_REQUEST_ID] = ctx.trace_id + + # Include initiator info if available + if ctx.initiator_user: + headers[HEADER_INITIATOR_USER] = ctx.initiator_user + if ctx.initiator_ip: + headers[HEADER_INITIATOR_IP] = ctx.initiator_ip + if ctx.initiator_role: + headers[HEADER_INITIATOR_ROLE] = ctx.initiator_role + if ctx.initiator_cert_serial: + headers[HEADER_INITIATOR_CERT_SERIAL] = ctx.initiator_cert_serial + if ctx.initiator_session: + headers[HEADER_INITIATOR_SESSION] = ctx.initiator_session + + return headers + + +def inject_audit_context(headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """ + Inject audit context into an existing headers dict. + + Convenience function that merges propagation headers with + existing headers. Existing headers are NOT overwritten. + + Args: + headers: Existing headers dict, or None to create new one. + + Returns: + Headers dict with audit context added. + + Example:: + + headers = {"Content-Type": "application/json"} + headers = inject_audit_context(headers) + await session.post(url, headers=headers, json=data) + """ + result = dict(headers) if headers else {} + propagation = get_propagation_headers() + + # Add propagation headers, don't overwrite existing + for key, value in propagation.items(): + if key not in result: + result[key] = value + + return result + + +class AuditContextClientMixin: + """ + Mixin for HTTP clients that automatically propagates audit context. + + Can be used as a mixin for custom client classes. + + Example:: + + class ProductClient(AuditContextClientMixin): + def __init__(self, base_url: str): + self.base_url = base_url + + async def notify_enrollment(self, callsign: str): + headers = self.get_audit_headers() + async with aiohttp.ClientSession() as session: + await session.post( + f"{self.base_url}/api/v1/enrolled", + json={"callsign": callsign}, + headers=headers, + ) + """ + + def get_audit_headers(self) -> Dict[str, str]: + """Get headers with audit context for HTTP requests.""" + return get_propagation_headers() + + def merge_audit_headers(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: + """Merge audit headers with existing headers.""" + return inject_audit_context(headers) + + +def create_audit_trace_config() -> Any: + """ + Create aiohttp TraceConfig that adds audit headers to all requests. + + This automatically injects propagation headers into every request + made by the ClientSession. + + Returns: + aiohttp.TraceConfig instance. + + Example:: + + import aiohttp + from libpvarki.auditlogging import create_audit_trace_config + + trace_config = create_audit_trace_config() + async with aiohttp.ClientSession(trace_configs=[trace_config]) as session: + # All requests automatically include audit headers + await session.get("http://other-service/api/v1/status") + """ + try: + import aiohttp + + async def on_request_start( + session: aiohttp.ClientSession, + trace_config_ctx: Any, + params: aiohttp.TraceRequestStartParams, + ) -> None: + """Add audit headers before each request.""" + headers = get_propagation_headers() + for key, value in headers.items(): + if key not in params.headers: + params.headers[key] = value + + trace_config = aiohttp.TraceConfig() + trace_config.on_request_start.append(on_request_start) + return trace_config + + except ImportError: + raise ImportError( + "aiohttp is required for create_audit_trace_config(). " + "This should already be installed as a libpvarki dependency." + ) diff --git a/src/libpvarki/logging.py b/src/libpvarki/logging.py index 02ead3d..7ca3bed 100644 --- a/src/libpvarki/logging.py +++ b/src/libpvarki/logging.py @@ -74,9 +74,11 @@ def __init__(self, extras: Mapping[str, Any], name: str = "") -> None: def filter(self, record: logging.LogRecord) -> bool: """Add the extras then call parent filter""" - for key in self.add_extras: - setattr(record, key, self.add_extras[key]) - return super().filter(record) + for key, value in self.add_extras.items(): + setattr(record, key, value) + + result = super().filter(record) + return bool(result) def init_logging(level: int = logging.INFO) -> None: diff --git a/tests/test_auditlogging.py b/tests/test_auditlogging.py new file mode 100644 index 0000000..106efc2 --- /dev/null +++ b/tests/test_auditlogging.py @@ -0,0 +1,478 @@ +""" +Tests for libpvarki.auditlogging module. + +Run with: pytest tests/test_auditlogging.py -v + +These tests are designed to work within the libpvarki test infrastructure. +""" + +import logging +from typing import Generator, Dict, Any + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from libpvarki.auditlogging import ( + init_audit, + AUDIT, + AuditMiddleware, + audit_log, + audit_extra, + audit_authentication, + audit_iam, + get_audit_context, + set_audit_context, + clear_audit_context, + get_propagation_headers, + inject_audit_context, + update_audit_user, +) + + +# ============================================================================= +# Fixtures +# ============================================================================= + + +@pytest.fixture(autouse=True) +def reset_context() -> Generator[None, None, None]: + """Clear audit context before and after each test.""" + clear_audit_context() + yield + clear_audit_context() + + +@pytest.fixture +def app() -> FastAPI: + """Create a test FastAPI app with AuditMiddleware.""" + app = FastAPI() + app.add_middleware(AuditMiddleware) + + @app.get("/test") + async def test_endpoint() -> Dict[str, Any]: + ctx = get_audit_context() + return { + "trace_id": ctx.trace_id, + "initiator_user": ctx.initiator_user, + "initiator_ip": ctx.initiator_ip, + "is_propagated": ctx.is_propagated, + } + + @app.post("/enroll") + async def enroll_endpoint() -> Dict[str, str]: + ctx = get_audit_context() + return {"enrolled": ctx.initiator_user} + + return app + + +@pytest.fixture +def client(app: FastAPI) -> TestClient: + """Test client for the app.""" + return TestClient(app) + + +# ============================================================================= +# AUDIT Level Tests +# ============================================================================= + + +class TestAuditLevel: + """Tests for AUDIT log level setup.""" + + def test_audit_level_constant(self) -> None: + """AUDIT level should be 25.""" + assert AUDIT == 25 + + def test_audit_level_registered(self) -> None: + """AUDIT level should be registered with logging module.""" + assert hasattr(logging, "AUDIT") + assert logging.AUDIT == 25 # type: ignore[attr-defined] + + def test_logger_has_audit_method(self) -> None: + """Logger instances should have audit() method.""" + logger = logging.getLogger("test.audit_level") + assert hasattr(logger, "audit") + assert callable(logger.audit) + + def test_audit_level_name(self) -> None: + """AUDIT level should have correct name.""" + assert logging.getLevelName(25) == "AUDIT" + assert logging.getLevelName("AUDIT") == 25 + + +# ============================================================================= +# Context Tests +# ============================================================================= + + +class TestAuditContext: + """Tests for ContextVar-based audit context.""" + + def test_default_context(self) -> None: + """Default context should have generated trace_id.""" + ctx = get_audit_context() + assert ctx.trace_id # Should be non-empty UUID + assert ctx.initiator_user == "" + assert ctx.initiator_ip == "" + assert ctx.is_propagated is False + + def test_set_context(self) -> None: + """set_audit_context should update fields.""" + set_audit_context( + trace_id="test-trace-123", + initiator_user="NORPPA11", + initiator_ip="192.168.1.100", + ) + ctx = get_audit_context() + assert ctx.trace_id == "test-trace-123" + assert ctx.initiator_user == "NORPPA11" + assert ctx.initiator_ip == "192.168.1.100" + + def test_partial_update(self) -> None: + """set_audit_context should only update provided fields.""" + set_audit_context(initiator_user="KOTKA1") + set_audit_context(initiator_role="admin") + + ctx = get_audit_context() + assert ctx.initiator_user == "KOTKA1" + assert ctx.initiator_role == "admin" + + def test_clear_context(self) -> None: + """clear_audit_context should reset to defaults.""" + set_audit_context(initiator_user="NORPPA11") + clear_audit_context() + + ctx = get_audit_context() + assert ctx.initiator_user == "" + + def test_to_ecs_fields(self) -> None: + """Context should convert to ECS field dict.""" + set_audit_context( + trace_id="abc-123", + initiator_user="NORPPA11", + initiator_ip="10.0.0.1", + initiator_role="operator", + initiator_cert_serial="DEADBEEF", + ) + ctx = get_audit_context() + fields = ctx.to_ecs_fields() + + assert fields["trace.id"] == "abc-123" + assert fields["source.user.name"] == "NORPPA11" + assert fields["source.ip"] == "10.0.0.1" + assert fields["source.user.roles"] == ["operator"] + assert fields["tls.client.x509.serial_number"] == "DEADBEEF" + + +# ============================================================================= +# Middleware Tests +# ============================================================================= + + +class TestAuditMiddleware: + """Tests for FastAPI middleware.""" + + def test_extracts_request_id(self, client: TestClient) -> None: + """Middleware should extract X-Request-ID.""" + response = client.get("/test", headers={"X-Request-ID": "my-trace-id"}) + assert response.status_code == 200 + assert response.json()["trace_id"] == "my-trace-id" + + def test_generates_request_id(self, client: TestClient) -> None: + """Middleware should generate trace ID if not provided.""" + response = client.get("/test") + assert response.status_code == 200 + assert response.json()["trace_id"] # Non-empty + + def test_returns_request_id_header(self, client: TestClient) -> None: + """Response should include X-Request-ID header.""" + response = client.get("/test", headers={"X-Request-ID": "echo-me"}) + assert response.headers["X-Request-ID"] == "echo-me" + + def test_extracts_real_ip(self, client: TestClient) -> None: + """Middleware should extract X-Real-IP.""" + response = client.get("/test", headers={"X-Real-IP": "203.0.113.50"}) + assert response.json()["initiator_ip"] == "203.0.113.50" + + def test_extracts_forwarded_for(self, client: TestClient) -> None: + """Middleware should extract X-Forwarded-For if no X-Real-IP.""" + response = client.get("/test", headers={"X-Forwarded-For": "203.0.113.50, 10.0.0.1"}) + assert response.json()["initiator_ip"] == "203.0.113.50" + + def test_extracts_cert_dn(self, client: TestClient) -> None: + """Middleware should extract CN from X-ClientCert-DN.""" + response = client.get( + "/test", + headers={"X-ClientCert-DN": "CN=NORPPA11,O=PVARKI,C=FI"}, + ) + assert response.json()["initiator_user"] == "NORPPA11" + + def test_propagated_headers(self, client: TestClient) -> None: + """Middleware should extract X-Initiator-* headers.""" + response = client.get( + "/test", + headers={ + "X-Initiator-User": "KOTKA1", + "X-Initiator-IP": "192.168.1.50", + }, + ) + data = response.json() + assert data["initiator_user"] == "KOTKA1" + assert data["initiator_ip"] == "192.168.1.50" + assert data["is_propagated"] is True + + def test_direct_mtls_takes_precedence(self, client: TestClient) -> None: + """Direct mTLS cert should override propagated headers.""" + response = client.get( + "/test", + headers={ + "X-ClientCert-DN": "CN=DIRECT_USER,O=TEST", + "X-Initiator-User": "PROPAGATED_USER", + }, + ) + # Direct mTLS wins + assert response.json()["initiator_user"] == "DIRECT_USER" + assert response.json()["is_propagated"] is False + + +# ============================================================================= +# Helper Tests +# ============================================================================= + + +class TestAuditLogHelper: + """Tests for audit_log() helper function.""" + + def test_basic_audit_log(self) -> None: + """audit_log should create ECS-compliant dict.""" + set_audit_context(trace_id="test-123") + + extra = audit_log( + category="authentication", + action="otp_exchange", + outcome="success", + ) + + assert extra["event.category"] == "authentication" + assert extra["event.action"] == "otp_exchange" + assert extra["event.outcome"] == "success" + assert extra["trace.id"] == "test-123" + assert "service.name" in extra + + def test_audit_log_with_target(self) -> None: + """audit_log should include target fields.""" + extra = audit_log( + category="iam", + action="cert_issue", + outcome="success", + target_user="NORPPA11", + target_resource="DEADBEEF", + target_resource_type="certificate", + ) + + assert extra["user.target.name"] == "NORPPA11" + assert extra["pvarki.target.resource"] == "DEADBEEF" + assert extra["pvarki.target.resource_type"] == "certificate" + + def test_audit_log_with_error(self) -> None: + """audit_log should include error fields.""" + extra = audit_log( + category="authentication", + action="jwt_validate", + outcome="failure", + error_message="Token expired", + error_code="TOKEN_EXPIRED", + ) + + assert extra["event.outcome"] == "failure" + assert extra["error.message"] == "Token expired" + assert extra["error.code"] == "TOKEN_EXPIRED" + + def test_audit_log_uses_context(self) -> None: + """audit_log should include context initiator.""" + set_audit_context( + initiator_user="CONTEXT_USER", + initiator_ip="10.0.0.1", + ) + + extra = audit_log(category="test", action="test") + + assert extra["source.user.name"] == "CONTEXT_USER" + assert extra["source.ip"] == "10.0.0.1" + + def test_audit_log_override_context(self) -> None: + """Explicit params should override context.""" + set_audit_context(initiator_user="CONTEXT_USER") + + extra = audit_log( + category="test", + action="test", + initiator_user="OVERRIDE_USER", + ) + + assert extra["source.user.name"] == "OVERRIDE_USER" + + def test_audit_log_extra_fields(self) -> None: + """Extra fields should go under pvarki namespace.""" + extra = audit_log( + category="test", + action="test", + custom_field="custom_value", + another_field=123, + ) + + assert extra["pvarki.custom_field"] == "custom_value" + assert extra["pvarki.another_field"] == 123 + + def test_convenience_wrappers(self) -> None: + """Category convenience functions should work.""" + auth = audit_authentication("login", outcome="success") + assert auth["event.category"] == "authentication" + + iam = audit_iam("cert_issue", target_user="NORPPA11") + assert iam["event.category"] == "iam" + assert iam["user.target.name"] == "NORPPA11" + + +# ============================================================================= +# Propagation Tests +# ============================================================================= + + +class TestPropagation: + """Tests for service-to-service propagation.""" + + def test_get_propagation_headers(self) -> None: + """get_propagation_headers should return context as headers.""" + set_audit_context( + trace_id="prop-trace-123", + initiator_user="NORPPA11", + initiator_ip="192.168.1.100", + initiator_role="admin", + initiator_cert_serial="DEADBEEF", + ) + + headers = get_propagation_headers() + + assert headers["X-Request-ID"] == "prop-trace-123" + assert headers["X-Initiator-User"] == "NORPPA11" + assert headers["X-Initiator-IP"] == "192.168.1.100" + assert headers["X-Initiator-Role"] == "admin" + assert headers["X-Initiator-Cert-Serial"] == "DEADBEEF" + + def test_propagation_empty_context(self) -> None: + """Propagation headers should exclude empty values.""" + clear_audit_context() + ctx = get_audit_context() + + headers = get_propagation_headers() + + assert "X-Request-ID" in headers + assert "X-Initiator-User" not in headers + assert "X-Initiator-IP" not in headers + + def test_inject_audit_context(self) -> None: + """inject_audit_context should merge with existing headers.""" + set_audit_context(trace_id="inject-test") + + existing = {"Content-Type": "application/json"} + result = inject_audit_context(existing) + + assert result["Content-Type"] == "application/json" + assert result["X-Request-ID"] == "inject-test" + + def test_inject_no_overwrite(self) -> None: + """inject_audit_context should not overwrite existing headers.""" + set_audit_context(trace_id="new-value") + + existing = {"X-Request-ID": "existing-value"} + result = inject_audit_context(existing) + + assert result["X-Request-ID"] == "existing-value" + + +# ============================================================================= +# Update User Tests +# ============================================================================= + + +class TestUpdateAuditUser: + """Tests for update_audit_user helper.""" + + def test_update_user(self) -> None: + """update_audit_user should update context.""" + update_audit_user(user="NORPPA11", role="operator") + + ctx = get_audit_context() + assert ctx.initiator_user == "NORPPA11" + assert ctx.initiator_role == "operator" + + def test_update_preserves_other_fields(self) -> None: + """update_audit_user should preserve other context fields.""" + set_audit_context(trace_id="keep-me", initiator_ip="10.0.0.1") + update_audit_user(user="NORPPA11") + + ctx = get_audit_context() + assert ctx.trace_id == "keep-me" + assert ctx.initiator_ip == "10.0.0.1" + assert ctx.initiator_user == "NORPPA11" + + +# ============================================================================= +# Integration Test +# ============================================================================= + + +class TestIntegration: + """Integration test simulating real usage.""" + + def test_full_flow(self, client: TestClient) -> None: + """Test complete audit logging flow.""" + # Simulate mTLS request through nginx + response = client.post( + "/enroll", + headers={ + "X-Request-ID": "integration-test-123", + "X-ClientCert-DN": "CN=NORPPA11,O=PVARKI,C=FI", + "X-ClientCert-Serial": "DEADBEEF", + "X-Real-IP": "203.0.113.50", + }, + ) + + assert response.status_code == 200 + assert response.json()["enrolled"] == "NORPPA11" + assert response.headers["X-Request-ID"] == "integration-test-123" + + def test_service_chain(self, client: TestClient) -> None: + """Test context propagation through service chain.""" + # First service receives from nginx + response1 = client.get( + "/test", + headers={ + "X-Request-ID": "chain-trace-456", + "X-ClientCert-DN": "CN=ORIGINAL_USER,O=TEST", + }, + ) + assert response1.json()["initiator_user"] == "ORIGINAL_USER" + + # Second service receives propagated context + response2 = client.get( + "/test", + headers={ + "X-Request-ID": "chain-trace-456", + "X-Initiator-User": "ORIGINAL_USER", + "X-Initiator-IP": "10.0.0.1", + }, + ) + assert response2.json()["initiator_user"] == "ORIGINAL_USER" + assert response2.json()["is_propagated"] is True + + +# ============================================================================= +# Run tests +# ============================================================================= + +if __name__ == "__main__": + pytest.main([__file__, "-v"])