diff --git a/msal/__init__.py b/msal/__init__.py index 295e9756..81763bb1 100644 --- a/msal/__init__.py +++ b/msal/__init__.py @@ -38,6 +38,7 @@ SystemAssignedManagedIdentity, UserAssignedManagedIdentity, ManagedIdentityClient, ManagedIdentityError, + MsiV2Error, ArcPlatformNotSupportedError, ) diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 422b76e3..85d77a8b 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -24,6 +24,11 @@ class ManagedIdentityError(ValueError): pass +class MsiV2Error(ManagedIdentityError): + """Raised when the MSI v2 (mTLS PoP) flow fails.""" + pass + + class ManagedIdentity(UserDict): """Feed an instance of this class to :class:`msal.ManagedIdentityClient` to acquire token for the specified managed identity. @@ -259,6 +264,8 @@ def acquire_token_for_client( *, resource: str, # If/when we support scope, resource will become optional claims_challenge: Optional[str] = None, + mtls_proof_of_possession: bool = False, + with_attestation_support: bool = False, ): """Acquire token for the managed identity. @@ -278,6 +285,23 @@ def acquire_token_for_client( even if the app developer did not opt in for the "CP1" client capability. Upon receiving a `claims_challenge`, MSAL will attempt to acquire a new token. + :param bool mtls_proof_of_possession: (optional) + When True, use the MSI v2 (mTLS Proof-of-Possession) flow to acquire an + ``mtls_pop`` token bound to a short-lived mTLS certificate issued by the + IMDS ``/issuecredential`` endpoint. + Without this flag the legacy IMDS v1 flow is used. + Defaults to False. + + MSI v2 is used only when both ``mtls_proof_of_possession`` and + ``with_attestation_support`` are True. + + :param bool with_attestation_support: (optional) + When True (and ``mtls_proof_of_possession`` is also True), attempt + KeyGuard / platform attestation before credential issuance. + On Windows this leverages ``AttestationClientLib.dll`` when available; + on other platforms the parameter is silently ignored. + Defaults to False. + .. note:: Known issue: When an Azure VM has only one user-assigned managed identity, @@ -292,6 +316,27 @@ def acquire_token_for_client( client_id_in_cache = self._managed_identity.get( ManagedIdentity.ID, "SYSTEM_ASSIGNED_MANAGED_IDENTITY") now = time.time() + # MSI v2 is opt-in: use it only when BOTH mtls_proof_of_possession and + # with_attestation_support are explicitly requested by the caller. + # No auto-fallback: if MSI v2 is requested and fails, the error is raised. + use_msi_v2 = bool(mtls_proof_of_possession and with_attestation_support) + + if with_attestation_support and not mtls_proof_of_possession: + raise ManagedIdentityError( + "attestation_requires_pop", + "with_attestation_support=True requires mtls_proof_of_possession=True (mTLS PoP)." + ) + + if use_msi_v2: + from .msi_v2 import obtain_token as _obtain_token_v2 + result = _obtain_token_v2( + self._http_client, self._managed_identity, resource, + attestation_enabled=True, + ) + if "access_token" in result and "error" not in result: + result[self._TOKEN_SOURCE] = self._TOKEN_SOURCE_IDP + return result + if True: # Attempt cache search even if receiving claims_challenge, # because we want to locate the existing token (if any) and refresh it matches = self._token_cache.search( diff --git a/msal/msi_v2.py b/msal/msi_v2.py new file mode 100644 index 00000000..efd29c9e --- /dev/null +++ b/msal/msi_v2.py @@ -0,0 +1,1314 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +MSI v2 (IMDSv2) Managed Identity flow — Windows KeyGuard + Attestation + SChannel mTLS PoP. + +This module implements the MSI v2 token acquisition path using Windows native APIs via ctypes: + - CNG/NCrypt: create/open a KeyGuard-protected per-boot RSA key (non-exportable) + - Minimal DER/PKCS#10: build a CSR signed with RSA-PSS/SHA256 + - IMDS: call getplatformmetadata + issuecredential + - Crypt32: bind the issued certificate to the CNG private key + - WinHTTP/SChannel: acquire access token over mTLS (token_type=mtls_pop) + +Key behavior: + - Uses a *named per-boot key*: opens the key if it already exists for this boot; otherwise creates it. + - No MSI v1 fallback: any MSI v2 failure raises MsiV2Error. + - Production-ready handle management: all WinHTTP / Crypt32 / NCrypt handles are released. + - Returns certificate with token for mTLS with resource. + +Environment variables (optional): + - AZURE_POD_IDENTITY_AUTHORITY_HOST: override IMDS base URL (default http://169.254.169.254) + - MSAL_MSI_V2_KEY_NAME: override the per-boot key name (otherwise derived from metadata clientId) + - MSAL_MSI_V2_ATTESTATION_CACHE: "0" to disable MAA JWT caching (implemented in msi_v2_attestation.py) +""" + +from __future__ import annotations + +import base64 +import hashlib +import json +import logging +import os +import sys +import uuid +from typing import Any, Dict, Optional, Tuple, List + +logger = logging.getLogger(__name__) + +# ---------------------------- +# IMDS constants +# ---------------------------- + +_IMDS_DEFAULT_BASE = "http://169.254.169.254" +_IMDS_BASE_ENVVAR = "AZURE_POD_IDENTITY_AUTHORITY_HOST" + +_API_VERSION_QUERY_PARAM = "cred-api-version" +_IMDS_V2_API_VERSION = "2.0" + +_CSR_METADATA_PATH = "/metadata/identity/getplatformmetadata" +_ISSUE_CREDENTIAL_PATH = "/metadata/identity/issuecredential" +_ACQUIRE_ENTRA_TOKEN_PATH = "/oauth2/v2.0/token" + +_CU_ID_OID_STR = "1.3.6.1.4.1.311.90.2.10" + +# ---------------------------- +# NCrypt/CNG flags +# ---------------------------- + +# ncrypt.h flags for KeyGuard protection +_NCRYPT_USE_VIRTUAL_ISOLATION_FLAG = 0x00020000 +_NCRYPT_USE_PER_BOOT_KEY_FLAG = 0x00040000 + +_RSA_KEY_SIZE = 2048 + +# Legacy KeySpec values used with NCryptCreatePersistedKey.dwLegacyKeySpec and +# CRYPT_KEY_PROV_INFO.dwKeySpec when dwProvType==0 (CNG/KSP). +_AT_KEYEXCHANGE = 1 +_AT_SIGNATURE = 2 + +# Flags used by CRYPT_KEY_PROV_INFO.dwFlags for CNG keys (best-effort: no UI prompts) +_NCRYPT_SILENT_FLAG = 0x40 + +_KEY_NAME_ENVVAR = "MSAL_MSI_V2_KEY_NAME" + +# ---------------------------- +# Error codes +# ---------------------------- + +# Common NCRYPT "not found" style statuses (NTE_*). +_NTE_BAD_KEYSET = 0x80090016 +_NTE_NO_KEY = 0x8009000D +_NTE_NOT_FOUND = 0x80090011 +_NTE_EXISTS = 0x8009000F + +# Lazy-loaded Win32 API cache +_WIN32: Optional[Dict[str, Any]] = None + + +# ---------------------------- +# Compatibility helpers (tests + cross-language parity) +# ---------------------------- + +def get_cert_thumbprint_sha256(cert_pem: str) -> str: + """ + Return base64url(SHA256(der(cert))) without padding, for cnf.x5t#S256 comparisons. + Accepts a PEM certificate string. + + Args: + cert_pem: PEM-formatted certificate string + + Returns: + Base64url-encoded SHA256 thumbprint without padding, or empty string on error + """ + try: + from cryptography import x509 + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + cert = x509.load_pem_x509_certificate(cert_pem.encode("utf-8"), default_backend()) + der = cert.public_bytes(serialization.Encoding.DER) + digest = hashlib.sha256(der).digest() + return base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii") + except Exception: + # If cryptography isn't available, fail closed (binding cannot be verified) + return "" + + +def verify_cnf_binding(token: str, cert_pem: str) -> bool: + """ + Verify that JWT payload contains cnf.x5t#S256 matching the cert thumbprint. + + Args: + token: Access token (JWT format) + cert_pem: PEM-formatted certificate + + Returns: + True if token's cnf.x5t#S256 claim matches certificate thumbprint + """ + try: + parts = token.split(".") + if len(parts) != 3: + return False + + payload_b64 = parts[1] + payload_b64 += "=" * ((4 - len(payload_b64) % 4) % 4) + claims = json.loads(base64.urlsafe_b64decode(payload_b64.encode("ascii"))) + + cnf = claims.get("cnf", {}) if isinstance(claims, dict) else {} + token_x5t = cnf.get("x5t#S256") + if not token_x5t: + return False + + cert_x5t = get_cert_thumbprint_sha256(cert_pem) + if not cert_x5t: + return False + + return token_x5t == cert_x5t + except Exception: + return False + + +def _der_to_pem(der_bytes: bytes) -> str: + """ + Convert DER certificate bytes to PEM string format. + + Args: + der_bytes: DER-encoded certificate bytes + + Returns: + PEM-formatted certificate string + """ + b64 = base64.b64encode(der_bytes).decode("ascii") + # PEM line wrapping at 64 characters + lines = [b64[i:i+64] for i in range(0, len(b64), 64)] + return "-----BEGIN CERTIFICATE-----\n" + "\n".join(lines) + "\n-----END CERTIFICATE-----" + + +# ---------------------------- +# IMDS helpers +# ---------------------------- + +def _imds_base() -> str: + """Get IMDS base URL from environment or use default.""" + return os.getenv(_IMDS_BASE_ENVVAR, _IMDS_DEFAULT_BASE).strip().rstrip("/") + + +def _new_correlation_id() -> str: + """Generate a new correlation ID for request tracing.""" + return str(uuid.uuid4()) + + +def _imds_headers(correlation_id: Optional[str] = None) -> Dict[str, str]: + """Build IMDS request headers with Metadata: true and correlation ID.""" + return {"Metadata": "true", "x-ms-client-request-id": correlation_id or _new_correlation_id()} + + +def _resource_to_scope(resource_or_scope: str) -> str: + """ + Normalize resource to scope format (append /.default if needed). + + Args: + resource_or_scope: Resource URI or scope string + + Returns: + Normalized scope string ending with /.default + """ + s = (resource_or_scope or "").strip() + if not s: + raise ValueError("resource must be non-empty") + if s.endswith("/.default"): + return s + return s.rstrip("/") + "/.default" + + +def _der_utf8string(value: str) -> bytes: + """DER UTF8String encoder (tag 0x0C).""" + raw = value.encode("utf-8") + n = len(raw) + if n < 0x80: + len_bytes = bytes([n]) + else: + tmp = bytearray() + m = n + while m > 0: + tmp.insert(0, m & 0xFF) + m >>= 8 + len_bytes = bytes([0x80 | len(tmp)]) + bytes(tmp) + return bytes([0x0C]) + len_bytes + raw + + +def _json_loads(text: str, what: str) -> Dict[str, Any]: + """Parse JSON with error context.""" + from .managed_identity import MsiV2Error + try: + obj = json.loads(text) + if not isinstance(obj, dict): + raise TypeError("expected JSON object") + return obj + except Exception as exc: + raise MsiV2Error(f"[msi_v2] Invalid JSON from {what}: {text!r}") from exc + + +def _get_first(obj: Dict[str, Any], *names: str) -> Optional[str]: + """ + Get first non-empty value from object by multiple name variants (case-insensitive). + """ + for n in names: + if n in obj and obj[n] is not None and str(obj[n]).strip() != "": + return str(obj[n]) + lower = {str(k).lower(): k for k in obj.keys()} + for n in names: + k = lower.get(n.lower()) + if k and obj[k] is not None and str(obj[k]).strip() != "": + return str(obj[k]) + return None + + +def _mi_query_params(managed_identity: Optional[Dict[str, Any]]) -> Dict[str, str]: + """ + Adds cred-api-version=2.0 plus optional UAMI selector params. + managed_identity shape (MSAL python): {"ManagedIdentityIdType": "...", "Id": "..."} + """ + params: Dict[str, str] = {_API_VERSION_QUERY_PARAM: _IMDS_V2_API_VERSION} + if not isinstance(managed_identity, dict): + return params + + id_type = managed_identity.get("ManagedIdentityIdType") + identifier = managed_identity.get("Id") + + mapping = {"ClientId": "client_id", "ObjectId": "object_id", "ResourceId": "msi_res_id"} + wire = mapping.get(id_type) + if wire and identifier: + params[wire] = str(identifier) + return params + + +def _imds_get_json(http_client, url: str, params: Dict[str, str], headers: Dict[str, str]) -> Dict[str, Any]: + """GET request to IMDS with server verification.""" + from .managed_identity import MsiV2Error + resp = http_client.get(url, params=params, headers=headers) + server = (resp.headers or {}).get("server", "") + if "imds" not in str(server).lower(): + raise MsiV2Error(f"[msi_v2] IMDS server header check failed. server={server!r} url={url}") + if resp.status_code != 200: + raise MsiV2Error(f"[msi_v2] IMDSv2 GET {url} failed: HTTP {resp.status_code}: {resp.text}") + return _json_loads(resp.text, f"GET {url}") + + +def _imds_post_json(http_client, url: str, params: Dict[str, str], headers: Dict[str, str], body: Dict[str, Any]) -> Dict[str, Any]: + """POST request to IMDS with server verification.""" + from .managed_identity import MsiV2Error + resp = http_client.post(url, params=params, headers=headers, data=json.dumps(body, separators=(",", ":"))) + server = (resp.headers or {}).get("server", "") + if "imds" not in str(server).lower(): + raise MsiV2Error(f"[msi_v2] IMDS server header check failed. server={server!r} url={url}") + if resp.status_code != 200: + raise MsiV2Error(f"[msi_v2] IMDSv2 POST {url} failed: HTTP {resp.status_code}: {resp.text}") + return _json_loads(resp.text, f"POST {url}") + + +def _token_endpoint_from_credential(cred: Dict[str, Any]) -> str: + """ + Extract token endpoint from issuecredential response. + Prefers explicit token_endpoint, falls back to mtls_authentication_endpoint + tenant_id. + """ + token_endpoint = _get_first(cred, "token_endpoint", "tokenEndpoint") + if token_endpoint: + return token_endpoint + + mtls_auth = _get_first(cred, "mtls_authentication_endpoint", "mtlsAuthenticationEndpoint", "mtls_authenticationEndpoint") + tenant_id = _get_first(cred, "tenant_id", "tenantId") + if not mtls_auth or not tenant_id: + from .managed_identity import MsiV2Error + raise MsiV2Error(f"[msi_v2] issuecredential missing mtls_authentication_endpoint/tenant_id: {cred}") + + base = mtls_auth.rstrip("/") + "/" + tenant_id.strip("/") + return base + _ACQUIRE_ENTRA_TOKEN_PATH + + +# ---------------------------- +# Win32 primitives (ctypes) +# ---------------------------- + +def _load_win32() -> Dict[str, Any]: + """ + Lazy-load Win32 APIs via ctypes (safe to import on non-Windows). + + Initializes: + - ncrypt.dll: CNG/NCrypt key management + - crypt32.dll: X.509 certificate handling + - winhttp.dll: HTTP client with SChannel support + + Returns: + Dictionary mapping API names to ctypes objects + """ + global _WIN32 + + from .managed_identity import MsiV2Error + + if _WIN32 is not None: + return _WIN32 + + if sys.platform != "win32": + raise MsiV2Error("[msi_v2] KeyGuard + attested mTLS PoP is Windows-only.") + + import ctypes + from ctypes import wintypes + + ncrypt = ctypes.WinDLL("ncrypt.dll") + crypt32 = ctypes.WinDLL("crypt32.dll", use_last_error=True) + winhttp = ctypes.WinDLL("winhttp.dll", use_last_error=True) + + # Types + NCRYPT_PROV_HANDLE = ctypes.c_void_p + NCRYPT_KEY_HANDLE = ctypes.c_void_p + SECURITY_STATUS = ctypes.c_long + + class CERT_CONTEXT(ctypes.Structure): + _fields_ = [ + ("dwCertEncodingType", wintypes.DWORD), + ("pbCertEncoded", ctypes.POINTER(ctypes.c_ubyte)), + ("cbCertEncoded", wintypes.DWORD), + ("pCertInfo", ctypes.c_void_p), + ("hCertStore", ctypes.c_void_p), + ] + + PCCERT_CONTEXT = ctypes.POINTER(CERT_CONTEXT) + + class BCRYPT_PSS_PADDING_INFO(ctypes.Structure): + _fields_ = [ + ("pszAlgId", ctypes.c_wchar_p), + ("cbSalt", wintypes.ULONG), + ] + + # Constants + ERROR_SUCCESS = 0 + + # NCRYPT constants + NCRYPT_OVERWRITE_KEY_FLAG = 0x00000080 + NCRYPT_LENGTH_PROPERTY = "Length" + NCRYPT_EXPORT_POLICY_PROPERTY = "Export Policy" + NCRYPT_KEY_USAGE_PROPERTY = "Key Usage" + NCRYPT_ALLOW_SIGNING_FLAG = 0x00000002 + NCRYPT_ALLOW_DECRYPT_FLAG = 0x00000001 + BCRYPT_PAD_PSS = 0x00000008 + BCRYPT_SHA256_ALGORITHM = "SHA256" + BCRYPT_RSA_ALGORITHM = "RSA" + BCRYPT_RSAPUBLIC_BLOB = "RSAPUBLICBLOB" + BCRYPT_RSAPUBLIC_MAGIC = 0x31415352 # 'RSA1' + + # Crypt32 constants + X509_ASN_ENCODING = 0x00000001 + PKCS_7_ASN_ENCODING = 0x00010000 + CERT_NCRYPT_KEY_HANDLE_PROP_ID = 78 + CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG = 0x40000000 + + # WinHTTP constants + WINHTTP_ACCESS_TYPE_DEFAULT_PROXY = 0 + WINHTTP_FLAG_SECURE = 0x00800000 + WINHTTP_OPTION_CLIENT_CERT_CONTEXT = 47 + WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT = 161 + WINHTTP_QUERY_STATUS_CODE = 19 + WINHTTP_QUERY_FLAG_NUMBER = 0x20000000 + + # NCrypt prototypes + ncrypt.NCryptOpenStorageProvider.argtypes = [ctypes.POINTER(NCRYPT_PROV_HANDLE), ctypes.c_wchar_p, wintypes.DWORD] + ncrypt.NCryptOpenStorageProvider.restype = SECURITY_STATUS + + ncrypt.NCryptOpenKey.argtypes = [ + NCRYPT_PROV_HANDLE, + ctypes.POINTER(NCRYPT_KEY_HANDLE), + ctypes.c_wchar_p, + wintypes.DWORD, + wintypes.DWORD, + ] + ncrypt.NCryptOpenKey.restype = SECURITY_STATUS + + ncrypt.NCryptCreatePersistedKey.argtypes = [ + NCRYPT_PROV_HANDLE, + ctypes.POINTER(NCRYPT_KEY_HANDLE), + ctypes.c_wchar_p, + ctypes.c_wchar_p, + wintypes.DWORD, + wintypes.DWORD, + ] + ncrypt.NCryptCreatePersistedKey.restype = SECURITY_STATUS + + ncrypt.NCryptSetProperty.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD, wintypes.DWORD] + ncrypt.NCryptSetProperty.restype = SECURITY_STATUS + + ncrypt.NCryptFinalizeKey.argtypes = [NCRYPT_KEY_HANDLE, wintypes.DWORD] + ncrypt.NCryptFinalizeKey.restype = SECURITY_STATUS + + ncrypt.NCryptGetProperty.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD, ctypes.POINTER(wintypes.DWORD), wintypes.DWORD] + ncrypt.NCryptGetProperty.restype = SECURITY_STATUS + + ncrypt.NCryptExportKey.argtypes = [NCRYPT_KEY_HANDLE, NCRYPT_KEY_HANDLE, ctypes.c_wchar_p, ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, ctypes.POINTER(wintypes.DWORD), wintypes.DWORD] + ncrypt.NCryptExportKey.restype = SECURITY_STATUS + + ncrypt.NCryptSignHash.argtypes = [NCRYPT_KEY_HANDLE, ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD, ctypes.POINTER(wintypes.DWORD), wintypes.DWORD] + ncrypt.NCryptSignHash.restype = SECURITY_STATUS + + ncrypt.NCryptFreeObject.argtypes = [ctypes.c_void_p] + ncrypt.NCryptFreeObject.restype = SECURITY_STATUS + + # Crypt32 prototypes + crypt32.CertCreateCertificateContext.argtypes = [wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] + crypt32.CertCreateCertificateContext.restype = PCCERT_CONTEXT + + crypt32.CertSetCertificateContextProperty.argtypes = [PCCERT_CONTEXT, wintypes.DWORD, wintypes.DWORD, ctypes.c_void_p] + crypt32.CertSetCertificateContextProperty.restype = wintypes.BOOL + + crypt32.CertFreeCertificateContext.argtypes = [PCCERT_CONTEXT] + crypt32.CertFreeCertificateContext.restype = wintypes.BOOL + + # WinHTTP prototypes + winhttp.WinHttpOpen.argtypes = [ctypes.c_wchar_p, wintypes.DWORD, ctypes.c_wchar_p, ctypes.c_wchar_p, wintypes.DWORD] + winhttp.WinHttpOpen.restype = ctypes.c_void_p + + winhttp.WinHttpConnect.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, wintypes.WORD, wintypes.DWORD] + winhttp.WinHttpConnect.restype = ctypes.c_void_p + + winhttp.WinHttpOpenRequest.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_void_p, wintypes.DWORD] + winhttp.WinHttpOpenRequest.restype = ctypes.c_void_p + + winhttp.WinHttpSetOption.argtypes = [ctypes.c_void_p, wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD] + winhttp.WinHttpSetOption.restype = wintypes.BOOL + + winhttp.WinHttpSendRequest.argtypes = [ctypes.c_void_p, ctypes.c_wchar_p, wintypes.DWORD, ctypes.c_void_p, wintypes.DWORD, wintypes.DWORD, ctypes.c_ulonglong] + winhttp.WinHttpSendRequest.restype = wintypes.BOOL + + winhttp.WinHttpReceiveResponse.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + winhttp.WinHttpReceiveResponse.restype = wintypes.BOOL + + winhttp.WinHttpQueryHeaders.argtypes = [ctypes.c_void_p, wintypes.DWORD, ctypes.c_wchar_p, ctypes.c_void_p, ctypes.POINTER(wintypes.DWORD), ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpQueryHeaders.restype = wintypes.BOOL + + winhttp.WinHttpQueryDataAvailable.argtypes = [ctypes.c_void_p, ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpQueryDataAvailable.restype = wintypes.BOOL + + winhttp.WinHttpReadData.argtypes = [ctypes.c_void_p, ctypes.c_void_p, wintypes.DWORD, ctypes.POINTER(wintypes.DWORD)] + winhttp.WinHttpReadData.restype = wintypes.BOOL + + winhttp.WinHttpCloseHandle.argtypes = [ctypes.c_void_p] + winhttp.WinHttpCloseHandle.restype = wintypes.BOOL + + _WIN32 = { + "ctypes": ctypes, + "wintypes": wintypes, + "ncrypt": ncrypt, + "crypt32": crypt32, + "winhttp": winhttp, + "NCRYPT_PROV_HANDLE": NCRYPT_PROV_HANDLE, + "NCRYPT_KEY_HANDLE": NCRYPT_KEY_HANDLE, + "SECURITY_STATUS": SECURITY_STATUS, + "CERT_CONTEXT": CERT_CONTEXT, + "PCCERT_CONTEXT": PCCERT_CONTEXT, + "BCRYPT_PSS_PADDING_INFO": BCRYPT_PSS_PADDING_INFO, + "ERROR_SUCCESS": ERROR_SUCCESS, + "NCRYPT_OVERWRITE_KEY_FLAG": NCRYPT_OVERWRITE_KEY_FLAG, + "NCRYPT_LENGTH_PROPERTY": NCRYPT_LENGTH_PROPERTY, + "NCRYPT_EXPORT_POLICY_PROPERTY": NCRYPT_EXPORT_POLICY_PROPERTY, + "NCRYPT_KEY_USAGE_PROPERTY": NCRYPT_KEY_USAGE_PROPERTY, + "NCRYPT_ALLOW_SIGNING_FLAG": NCRYPT_ALLOW_SIGNING_FLAG, + "NCRYPT_ALLOW_DECRYPT_FLAG": NCRYPT_ALLOW_DECRYPT_FLAG, + "BCRYPT_PAD_PSS": BCRYPT_PAD_PSS, + "BCRYPT_SHA256_ALGORITHM": BCRYPT_SHA256_ALGORITHM, + "BCRYPT_RSA_ALGORITHM": BCRYPT_RSA_ALGORITHM, + "BCRYPT_RSAPUBLIC_BLOB": BCRYPT_RSAPUBLIC_BLOB, + "BCRYPT_RSAPUBLIC_MAGIC": BCRYPT_RSAPUBLIC_MAGIC, + "X509_ASN_ENCODING": X509_ASN_ENCODING, + "PKCS_7_ASN_ENCODING": PKCS_7_ASN_ENCODING, + "CERT_NCRYPT_KEY_HANDLE_PROP_ID": CERT_NCRYPT_KEY_HANDLE_PROP_ID, + "CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG": CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG, + "WINHTTP_ACCESS_TYPE_DEFAULT_PROXY": WINHTTP_ACCESS_TYPE_DEFAULT_PROXY, + "WINHTTP_FLAG_SECURE": WINHTTP_FLAG_SECURE, + "WINHTTP_OPTION_CLIENT_CERT_CONTEXT": WINHTTP_OPTION_CLIENT_CERT_CONTEXT, + "WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT": WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT, + "WINHTTP_QUERY_STATUS_CODE": WINHTTP_QUERY_STATUS_CODE, + "WINHTTP_QUERY_FLAG_NUMBER": WINHTTP_QUERY_FLAG_NUMBER, + } + return _WIN32 + + +def _raise_win32_last_error(msg: str) -> None: + """Raise MsiV2Error with the current Win32 last-error code.""" + from .managed_identity import MsiV2Error + win32 = _load_win32() + ctypes_mod = win32["ctypes"] + err = ctypes_mod.get_last_error() + detail = "" + try: + detail = ctypes_mod.FormatError(err).strip() + except Exception: + pass + if detail: + raise MsiV2Error(f"{msg} (winerror={err} {detail})") + raise MsiV2Error(f"{msg} (winerror={err})") + + +def _check_security_status(status: int, what: str) -> None: + """Check SECURITY_STATUS return codes from NCrypt (0 == success).""" + from .managed_identity import MsiV2Error + if int(status) != 0: + code_u32 = int(status) & 0xFFFFFFFF + raise MsiV2Error(f"[msi_v2] {what} failed: status=0x{code_u32:08X}") + + +def _status_u32(status: int) -> int: + """Convert signed status to unsigned 32-bit value.""" + return int(status) & 0xFFFFFFFF + + +def _is_key_not_found(status: int) -> bool: + """Check if status indicates key not found (NTE_*).""" + return _status_u32(status) in (_NTE_BAD_KEYSET, _NTE_NO_KEY, _NTE_NOT_FOUND) + + +# ---------------------------- +# DER helpers (minimal PKCS#10 CSR builder) +# ---------------------------- + +def _der_len(n: int) -> bytes: + """Encode DER length field.""" + if n < 0: + raise ValueError("DER length cannot be negative") + if n < 0x80: + return bytes([n]) + out = bytearray() + m = n + while m > 0: + out.insert(0, m & 0xFF) + m >>= 8 + return bytes([0x80 | len(out)]) + bytes(out) + + +def _der(tag: int, content: bytes) -> bytes: + """Encode DER TLV (tag-length-value).""" + return bytes([tag]) + _der_len(len(content)) + content + + +def _der_null() -> bytes: + """Encode DER NULL.""" + return b"\x05\x00" + + +def _der_integer(value: int) -> bytes: + """Encode DER INTEGER.""" + if value < 0: + raise ValueError("Only non-negative INTEGER supported") + if value == 0: + raw = b"\x00" + else: + raw = value.to_bytes((value.bit_length() + 7) // 8, "big") + if raw[0] & 0x80: + raw = b"\x00" + raw + return _der(0x02, raw) + + +def _der_oid(oid: str) -> bytes: + """Encode DER OID (e.g., "2.5.4.3").""" + parts = [int(x) for x in oid.split(".")] + if len(parts) < 2: + raise ValueError(f"Invalid OID: {oid}") + if parts[0] > 2 or parts[1] >= 40: + raise ValueError(f"Invalid OID: {oid}") + first = 40 * parts[0] + parts[1] + out = bytearray([first]) + for p in parts[2:]: + if p < 0: + raise ValueError(f"Invalid OID component: {oid}") + stack = bytearray() + if p == 0: + stack.append(0) + else: + m = p + while m > 0: + stack.insert(0, m & 0x7F) + m >>= 7 + for i in range(len(stack) - 1): + stack[i] |= 0x80 + out.extend(stack) + return _der(0x06, bytes(out)) + + +def _der_sequence(*items: bytes) -> bytes: + """Encode DER SEQUENCE.""" + return _der(0x30, b"".join(items)) + + +def _der_set(*items: bytes) -> bytes: + """Encode DER SET (sorted for canonical encoding).""" + enc = sorted(items) # DER SET requires sorting by full encoding + return _der(0x31, b"".join(enc)) + + +def _der_bitstring(data: bytes) -> bytes: + """Encode DER BITSTRING (no unused bits).""" + return _der(0x03, b"\x00" + data) # 0 unused bits + + +def _der_ia5string(value: str) -> bytes: + """Encode DER IA5String (ASCII-only).""" + raw = value.encode("ascii") + return _der(0x16, raw) + + +def _der_context_explicit(tagnum: int, inner: bytes) -> bytes: + """Encode context-specific EXPLICIT tag [tagnum].""" + return _der(0xA0 + tagnum, inner) + + +def _der_context_implicit_constructed(tagnum: int, inner_content: bytes) -> bytes: + """Context-specific IMPLICIT, constructed (CSR attributes use [0] IMPLICIT).""" + return _der(0xA0 + tagnum, inner_content) + + +def _der_name_cn_dc(cn: str, dc: str) -> bytes: + """Encode X.500 Name with CN and DC RDNs (CN UTF8String, DC IA5String if ASCII).""" + cn_atv = _der_sequence(_der_oid("2.5.4.3"), _der_utf8string(cn)) + cn_rdn = _der_set(cn_atv) + + try: + dc_value = _der_ia5string(dc) + except Exception: + dc_value = _der_utf8string(dc) + dc_atv = _der_sequence(_der_oid("0.9.2342.19200300.100.1.25"), dc_value) + dc_rdn = _der_set(dc_atv) + + return _der_sequence(cn_rdn, dc_rdn) + + +def _der_subject_public_key_info_rsa(modulus: int, exponent: int) -> bytes: + """Encode SubjectPublicKeyInfo for RSA public key.""" + rsa_pub = _der_sequence(_der_integer(modulus), _der_integer(exponent)) + alg = _der_sequence(_der_oid("1.2.840.113549.1.1.1"), _der_null()) # rsaEncryption + NULL + return _der_sequence(alg, _der_bitstring(rsa_pub)) + + +def _der_algid_rsapss_sha256() -> bytes: + """AlgorithmIdentifier for RSASSA-PSS with SHA-256, MGF1(SHA-256), saltLength=32.""" + sha256 = _der_sequence(_der_oid("2.16.840.1.101.3.4.2.1"), _der_null()) + mgf1 = _der_sequence(_der_oid("1.2.840.113549.1.1.8"), sha256) + salt_len = _der_integer(32) + trailer = _der_integer(1) + params = _der_sequence( + _der_context_explicit(0, sha256), + _der_context_explicit(1, mgf1), + _der_context_explicit(2, salt_len), + _der_context_explicit(3, trailer), + ) + return _der_sequence(_der_oid("1.2.840.113549.1.1.10"), params) + + +# ---------------------------- +# CNG/NCrypt wrappers +# ---------------------------- + +def _ncrypt_get_property(win32: Dict[str, Any], h: Any, name: str) -> bytes: + """Get NCrypt property (two-pass: query size, then read).""" + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + cb = wintypes.DWORD(0) + status = ncrypt.NCryptGetProperty(h, name, None, 0, ctypes_mod.byref(cb), 0) + if int(status) != 0 and cb.value == 0: + _check_security_status(status, f"NCryptGetProperty({name})") + if cb.value == 0: + return b"" + buf = (ctypes_mod.c_ubyte * cb.value)() + status = ncrypt.NCryptGetProperty(h, name, buf, cb.value, ctypes_mod.byref(cb), 0) + _check_security_status(status, f"NCryptGetProperty({name})") + return bytes(buf[: cb.value]) + + +def _stable_key_name(client_id: str) -> str: + """Generate a stable, CNG-safe key name from client ID.""" + base = (client_id or "").strip() + safe = [] + for ch in base: + if ch.isalnum() or ch in ("-", "_"): + safe.append(ch) + else: + safe.append("_") + # Max length: keep some headroom + return "MsalMsiV2Key_" + "".join(safe)[:90] + + +def _open_or_create_keyguard_rsa_key(win32: Dict[str, Any], *, key_name: str) -> Tuple[Any, Any, str, bool]: + """ + Open a named per-boot KeyGuard RSA key if it exists; otherwise create it. + + Returns: (prov_handle, key_handle, key_name, opened_existing) + """ + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + prov = win32["NCRYPT_PROV_HANDLE"]() + status = ncrypt.NCryptOpenStorageProvider(ctypes_mod.byref(prov), "Microsoft Software Key Storage Provider", 0) + _check_security_status(status, "NCryptOpenStorageProvider") + + key = win32["NCRYPT_KEY_HANDLE"]() + + # 1) Try open first + status = ncrypt.NCryptOpenKey(prov, ctypes_mod.byref(key), str(key_name), _AT_SIGNATURE, 0) + if int(status) == 0: + vi = _ncrypt_get_property(win32, key, "Virtual Iso") + if not vi or len(vi) < 4: + from .managed_identity import MsiV2Error + raise MsiV2Error("[msi_v2] Virtual Iso property missing/invalid; Credential Guard likely not active.") + return prov, key, str(key_name), True + + if not _is_key_not_found(status): + _check_security_status(status, f"NCryptOpenKey({key_name})") + + # 2) Create if missing + flags = ( + win32["NCRYPT_OVERWRITE_KEY_FLAG"] + | _NCRYPT_USE_VIRTUAL_ISOLATION_FLAG + | _NCRYPT_USE_PER_BOOT_KEY_FLAG + ) + + status = ncrypt.NCryptCreatePersistedKey( + prov, + ctypes_mod.byref(key), + win32["BCRYPT_RSA_ALGORITHM"], + str(key_name), + _AT_SIGNATURE, + flags, + ) + + if _status_u32(status) == _NTE_EXISTS: + # Race: another thread/process created it. + status2 = ncrypt.NCryptOpenKey(prov, ctypes_mod.byref(key), str(key_name), _AT_SIGNATURE, 0) + _check_security_status(status2, f"NCryptOpenKey({key_name}) after exists") + return prov, key, str(key_name), True + + _check_security_status(status, "NCryptCreatePersistedKey") + + # Set properties only when we created the key. + length = wintypes.DWORD(int(_RSA_KEY_SIZE)) + status = ncrypt.NCryptSetProperty(key, win32["NCRYPT_LENGTH_PROPERTY"], ctypes_mod.byref(length), ctypes_mod.sizeof(length), 0) + _check_security_status(status, "NCryptSetProperty(Length)") + + usage = wintypes.DWORD(win32["NCRYPT_ALLOW_SIGNING_FLAG"] | win32["NCRYPT_ALLOW_DECRYPT_FLAG"]) + status = ncrypt.NCryptSetProperty(key, win32["NCRYPT_KEY_USAGE_PROPERTY"], ctypes_mod.byref(usage), ctypes_mod.sizeof(usage), 0) + _check_security_status(status, "NCryptSetProperty(Key Usage)") + + export_policy = wintypes.DWORD(0) # non-exportable + status = ncrypt.NCryptSetProperty(key, win32["NCRYPT_EXPORT_POLICY_PROPERTY"], ctypes_mod.byref(export_policy), ctypes_mod.sizeof(export_policy), 0) + _check_security_status(status, "NCryptSetProperty(Export Policy)") + + status = ncrypt.NCryptFinalizeKey(key, 0) + _check_security_status(status, "NCryptFinalizeKey") + + vi = _ncrypt_get_property(win32, key, "Virtual Iso") + if not vi or len(vi) < 4: + from .managed_identity import MsiV2Error + raise MsiV2Error("[msi_v2] Virtual Iso property not available; Credential Guard likely not active.") + + return prov, key, str(key_name), False + + +def _ncrypt_export_rsa_public(win32: Dict[str, Any], key: Any) -> Tuple[int, int]: + """Export RSA public key (modulus, exponent) from an NCrypt key handle.""" + from .managed_identity import MsiV2Error + + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + cb = wintypes.DWORD(0) + status = ncrypt.NCryptExportKey(key, None, win32["BCRYPT_RSAPUBLIC_BLOB"], None, None, 0, ctypes_mod.byref(cb), 0) + if int(status) != 0 and cb.value == 0: + _check_security_status(status, "NCryptExportKey(size)") + if cb.value == 0: + raise MsiV2Error("[msi_v2] NCryptExportKey returned empty public blob size") + + buf = (ctypes_mod.c_ubyte * cb.value)() + status = ncrypt.NCryptExportKey(key, None, win32["BCRYPT_RSAPUBLIC_BLOB"], None, buf, cb.value, ctypes_mod.byref(cb), 0) + _check_security_status(status, "NCryptExportKey(RSAPUBLICBLOB)") + blob = bytes(buf[: cb.value]) + + if len(blob) < 24: + raise MsiV2Error("[msi_v2] RSAPUBLICBLOB too small") + + import struct + magic, bitlen, cb_exp, cb_mod, cb_p1, cb_p2 = struct.unpack("<6I", blob[:24]) + if magic != win32["BCRYPT_RSAPUBLIC_MAGIC"]: + raise MsiV2Error(f"[msi_v2] RSAPUBLICBLOB magic mismatch: 0x{magic:08X}") + if cb_p1 != 0 or cb_p2 != 0: + logger.debug("[msi_v2] RSAPUBLICBLOB contains primes unexpectedly (ignored).") + + offset = 24 + if len(blob) < offset + cb_exp + cb_mod: + raise MsiV2Error("[msi_v2] RSAPUBLICBLOB truncated") + + exp_bytes = blob[offset : offset + cb_exp] + offset += cb_exp + mod_bytes = blob[offset : offset + cb_mod] + + exponent = int.from_bytes(exp_bytes, "big") + modulus = int.from_bytes(mod_bytes, "big") + + if bitlen != modulus.bit_length(): + logger.debug("[msi_v2] RSA bit length mismatch: header=%d computed=%d", bitlen, modulus.bit_length()) + + return modulus, exponent + + +def _ncrypt_sign_pss_sha256(win32: Dict[str, Any], key: Any, digest: bytes) -> bytes: + """Sign a SHA-256 digest using RSA-PSS via NCryptSignHash.""" + from .managed_identity import MsiV2Error + + if len(digest) != 32: + raise MsiV2Error("[msi_v2] Expected SHA-256 digest (32 bytes)") + + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + ncrypt = win32["ncrypt"] + + PaddingInfo = win32["BCRYPT_PSS_PADDING_INFO"] + pad = PaddingInfo(win32["BCRYPT_SHA256_ALGORITHM"], 32) + + hash_buf = (ctypes_mod.c_ubyte * len(digest)).from_buffer_copy(digest) + + cb_sig = wintypes.DWORD(0) + status = ncrypt.NCryptSignHash(key, ctypes_mod.byref(pad), hash_buf, len(digest), None, 0, ctypes_mod.byref(cb_sig), win32["BCRYPT_PAD_PSS"]) + if int(status) != 0 and cb_sig.value == 0: + _check_security_status(status, "NCryptSignHash(size)") + if cb_sig.value == 0: + raise MsiV2Error("[msi_v2] NCryptSignHash returned empty signature size") + + sig_buf = (ctypes_mod.c_ubyte * cb_sig.value)() + status = ncrypt.NCryptSignHash(key, ctypes_mod.byref(pad), hash_buf, len(digest), sig_buf, cb_sig.value, ctypes_mod.byref(cb_sig), win32["BCRYPT_PAD_PSS"]) + _check_security_status(status, "NCryptSignHash") + return bytes(sig_buf[: cb_sig.value]) + + +# ---------------------------- +# CSR builder (KeyGuard key handle) +# ---------------------------- + +def _build_csr_b64(win32: Dict[str, Any], key: Any, client_id: str, tenant_id: str, cu_id: Any) -> str: + """ + Build CSR signed by KeyGuard key (RSA-PSS SHA256), including cuId request attribute. + """ + modulus, exponent = _ncrypt_export_rsa_public(win32, key) + + subject = _der_name_cn_dc(client_id, tenant_id) + spki = _der_subject_public_key_info_rsa(modulus, exponent) + + cuid_json = json.dumps(cu_id, separators=(",", ":"), ensure_ascii=False) + cuid_val = _der_utf8string(cuid_json) + + # Attribute: SEQUENCE { OID, SET { } } + attr = _der_sequence(_der_oid(_CU_ID_OID_STR), _der_set(cuid_val)) + + # attributes [0] IMPLICIT SET OF Attribute + attrs_content = b"".join(sorted([attr])) + attrs = _der_context_implicit_constructed(0, attrs_content) + + cri = _der_sequence(_der_integer(0), subject, spki, attrs) + + digest = hashlib.sha256(cri).digest() + signature = _ncrypt_sign_pss_sha256(win32, key, digest) + + csr = _der_sequence(cri, _der_algid_rsapss_sha256(), _der_bitstring(signature)) + return base64.b64encode(csr).decode("ascii") + + +# ---------------------------- +# Certificate binding + WinHTTP mTLS +# ---------------------------- + +def _create_cert_context_with_key( + win32: Dict[str, Any], + cert_der: bytes, + key: Any, + key_name: str, + *, + ksp_name: str = "Microsoft Software Key Storage Provider", +) -> Tuple[Any, Any, Tuple[Any, ...]]: + """ + Create a CERT_CONTEXT from DER bytes and associate it with a CNG private key. + + We set multiple properties to maximize compatibility across Schannel consumers: + - CERT_NCRYPT_KEY_HANDLE_PROP_ID (78): direct handle + - CERT_KEY_CONTEXT_PROP_ID (5): CERT_KEY_CONTEXT union (best-effort) + - CERT_KEY_PROV_INFO_PROP_ID (2): CRYPT_KEY_PROV_INFO so Schannel can reopen by name + + NOTE: For CNG keys (dwProvType==0), CRYPT_KEY_PROV_INFO.dwKeySpec must be AT_SIGNATURE or AT_KEYEXCHANGE. + """ + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + crypt32 = win32["crypt32"] + + enc = win32["X509_ASN_ENCODING"] | win32["PKCS_7_ASN_ENCODING"] + + buf = ctypes_mod.create_string_buffer(cert_der) + ctx = crypt32.CertCreateCertificateContext(enc, buf, len(cert_der)) + if not ctx: + _raise_win32_last_error("[msi_v2] CertCreateCertificateContext failed") + + keepalive: List[Any] = [buf] + + try: + # (A) direct handle + key_handle = ctypes_mod.c_void_p(int(key.value)) + keepalive.append(key_handle) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, + win32["CERT_NCRYPT_KEY_HANDLE_PROP_ID"], + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(key_handle), + ) + if not ok: + _raise_win32_last_error("[msi_v2] CertSetCertificateContextProperty(CERT_NCRYPT_KEY_HANDLE_PROP_ID) failed") + + # (B) CERT_KEY_CONTEXT_PROP_ID (best-effort) + CERT_KEY_CONTEXT_PROP_ID = 5 + CERT_NCRYPT_KEY_SPEC = 0xFFFFFFFF + + class CERT_KEY_CONTEXT(ctypes_mod.Structure): + _fields_ = [ + ("cbSize", wintypes.DWORD), + ("hCryptProvOrNCryptKey", ctypes_mod.c_void_p), + ("dwKeySpec", wintypes.DWORD), + ] + + key_ctx = CERT_KEY_CONTEXT(ctypes_mod.sizeof(CERT_KEY_CONTEXT), key_handle, wintypes.DWORD(CERT_NCRYPT_KEY_SPEC)) + keepalive.append(key_ctx) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, + CERT_KEY_CONTEXT_PROP_ID, + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(key_ctx), + ) + if not ok: + logger.debug("[msi_v2] Failed to set CERT_KEY_CONTEXT_PROP_ID (last_error=%s)", ctypes_mod.get_last_error()) + + # (C) CERT_KEY_PROV_INFO_PROP_ID (so Schannel can reopen by name) + CERT_KEY_PROV_INFO_PROP_ID = 2 + + class CRYPT_KEY_PROV_INFO(ctypes_mod.Structure): + _fields_ = [ + ("pwszContainerName", wintypes.LPWSTR), + ("pwszProvName", wintypes.LPWSTR), + ("dwProvType", wintypes.DWORD), + ("dwFlags", wintypes.DWORD), + ("cProvParam", wintypes.DWORD), + ("rgProvParam", ctypes_mod.c_void_p), + ("dwKeySpec", wintypes.DWORD), + ] + + container_buf = ctypes_mod.create_unicode_buffer(str(key_name)) + provider_buf = ctypes_mod.create_unicode_buffer(str(ksp_name)) + keepalive.extend([container_buf, provider_buf]) + + prov_info = CRYPT_KEY_PROV_INFO( + ctypes_mod.cast(container_buf, wintypes.LPWSTR), + ctypes_mod.cast(provider_buf, wintypes.LPWSTR), + wintypes.DWORD(0), # dwProvType=0 for CNG/KSP + wintypes.DWORD(_NCRYPT_SILENT_FLAG), + wintypes.DWORD(0), + None, + wintypes.DWORD(_AT_SIGNATURE), # IMPORTANT + ) + keepalive.append(prov_info) + + ok = crypt32.CertSetCertificateContextProperty( + ctx, + CERT_KEY_PROV_INFO_PROP_ID, + win32["CERT_SET_PROPERTY_INHIBIT_PERSIST_FLAG"], + ctypes_mod.byref(prov_info), + ) + if not ok: + logger.debug("[msi_v2] Failed to set CERT_KEY_PROV_INFO_PROP_ID (last_error=%s)", ctypes_mod.get_last_error()) + + return ctx, buf, tuple(keepalive) + + except Exception: + try: + crypt32.CertFreeCertificateContext(ctx) + except Exception: + pass + raise + + +def _winhttp_close(win32: Dict[str, Any], h: Any) -> None: + """Close WinHTTP handle safely.""" + try: + if h: + win32["winhttp"].WinHttpCloseHandle(h) + except Exception: + pass + + +def _winhttp_post(win32: Dict[str, Any], url: str, cert_ctx: Any, body: bytes, headers: Dict[str, str]) -> Tuple[int, bytes]: + """ + POST bytes to https URL using WinHTTP + SChannel, presenting the provided cert context. + """ + from .managed_identity import MsiV2Error + from urllib.parse import urlparse + + ctypes_mod = win32["ctypes"] + wintypes = win32["wintypes"] + winhttp = win32["winhttp"] + + u = urlparse(url) + if u.scheme.lower() != "https": + raise MsiV2Error(f"[msi_v2] Token endpoint must be https, got: {url!r}") + if not u.hostname: + raise MsiV2Error(f"[msi_v2] Invalid token endpoint: {url!r}") + + host = u.hostname + port = u.port or 443 + path = u.path or "/" + if u.query: + path += "?" + u.query + + h_session = winhttp.WinHttpOpen("msal-python-msi-v2", win32["WINHTTP_ACCESS_TYPE_DEFAULT_PROXY"], None, None, 0) + if not h_session: + _raise_win32_last_error("[msi_v2] WinHttpOpen failed") + + h_connect = None + h_request = None + try: + # Best-effort: allow WinHTTP to send client cert for HTTP/2 negotiated connections. + enable = wintypes.DWORD(1) + try: + winhttp.WinHttpSetOption(h_session, win32["WINHTTP_OPTION_ENABLE_HTTP2_PLUS_CLIENT_CERT"], ctypes_mod.byref(enable), ctypes_mod.sizeof(enable)) + except Exception: + pass + + h_connect = winhttp.WinHttpConnect(h_session, host, int(port), 0) + if not h_connect: + _raise_win32_last_error("[msi_v2] WinHttpConnect failed") + + h_request = winhttp.WinHttpOpenRequest(h_connect, "POST", path, None, None, None, win32["WINHTTP_FLAG_SECURE"]) + if not h_request: + _raise_win32_last_error("[msi_v2] WinHttpOpenRequest failed") + + # Attach cert context (mTLS). + CertContext = win32["CERT_CONTEXT"] + ok = winhttp.WinHttpSetOption(h_request, win32["WINHTTP_OPTION_CLIENT_CERT_CONTEXT"], cert_ctx, ctypes_mod.sizeof(CertContext)) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpSetOption(WINHTTP_OPTION_CLIENT_CERT_CONTEXT) failed") + + header_lines = "".join(f"{k}: {v}\r\n" for k, v in headers.items()) + + body_buf = ctypes_mod.create_string_buffer(body) + ok = winhttp.WinHttpSendRequest(h_request, header_lines, 0xFFFFFFFF, body_buf, len(body), len(body), 0) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpSendRequest failed") + + ok = winhttp.WinHttpReceiveResponse(h_request, None) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpReceiveResponse failed") + + status = wintypes.DWORD(0) + status_size = wintypes.DWORD(ctypes_mod.sizeof(status)) + index = wintypes.DWORD(0) + + ok = winhttp.WinHttpQueryHeaders( + h_request, + win32["WINHTTP_QUERY_STATUS_CODE"] | win32["WINHTTP_QUERY_FLAG_NUMBER"], + None, + ctypes_mod.byref(status), + ctypes_mod.byref(status_size), + ctypes_mod.byref(index), + ) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpQueryHeaders(WINHTTP_QUERY_STATUS_CODE) failed") + + chunks: List[bytes] = [] + while True: + avail = wintypes.DWORD(0) + ok = winhttp.WinHttpQueryDataAvailable(h_request, ctypes_mod.byref(avail)) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpQueryDataAvailable failed") + if avail.value == 0: + break + buf = (ctypes_mod.c_ubyte * avail.value)() + read = wintypes.DWORD(0) + ok = winhttp.WinHttpReadData(h_request, buf, avail.value, ctypes_mod.byref(read)) + if not ok: + _raise_win32_last_error("[msi_v2] WinHttpReadData failed") + if read.value: + chunks.append(bytes(buf[: read.value])) + if read.value == 0: + break + + return int(status.value), b"".join(chunks) + finally: + _winhttp_close(win32, h_request) + _winhttp_close(win32, h_connect) + _winhttp_close(win32, h_session) + + +def _acquire_token_mtls_schannel(win32: Dict[str, Any], token_endpoint: str, cert_ctx: Any, client_id: str, scope: str) -> Dict[str, Any]: + """Acquire an mtls_pop token from ESTS using WinHTTP/SChannel.""" + from .managed_identity import MsiV2Error + from urllib.parse import urlencode + + form = urlencode({ + "grant_type": "client_credentials", + "client_id": client_id, + "scope": scope, + "token_type": "mtls_pop", + }).encode("utf-8") + + status, resp_body = _winhttp_post( + win32, + token_endpoint, + cert_ctx, + form, + headers={"Content-Type": "application/x-www-form-urlencoded", "Accept": "application/json"}, + ) + + text = resp_body.decode("utf-8", errors="replace") + if status < 200 or status >= 300: + raise MsiV2Error(f"[msi_v2] ESTS token request failed: HTTP {status} Body={text!r}") + return _json_loads(text, "ESTS token") + + +# ---------------------------- +# Public API +# ---------------------------- + +def obtain_token( + http_client, + managed_identity: Dict[str, Any], + resource: str, + *, + attestation_enabled: bool = True, +) -> Dict[str, Any]: + """ + Acquire mtls_pop token using Windows KeyGuard + MAA attestation. + + Flow: + 1. getplatformmetadata: fetch client_id, tenant_id, cu_id, attestationEndpoint + 2. Open/create named per-boot KeyGuard RSA key (non-exportable) + 3. Build PKCS#10 CSR with cuId attribute, sign with RSA-PSS/SHA256 + 4. Get attestation JWT from MAA (or cached) + 5. issuecredential: submit CSR + attestation → get X.509 cert + 6. Create CERT_CONTEXT, bind to KeyGuard private key + 7. POST /oauth2/v2.0/token via WinHTTP/SChannel with mTLS + + Returns: + { + "access_token": "...", + "expires_in": 3600, + "token_type": "mtls_pop", + "resource": "...", + "cert_pem": "-----BEGIN CERTIFICATE-----\\n...", # For mTLS with resource + "cert_der_b64": "base64-encoded...", # DER format + "cert_thumbprint_sha256": "...", # For verification + } + + Raises: + MsiV2Error: on any failure (no fallback to MSI v1) + """ + from .managed_identity import MsiV2Error + + win32 = _load_win32() + ncrypt = win32["ncrypt"] + crypt32 = win32["crypt32"] + + base = _imds_base() + params = _mi_query_params(managed_identity) + corr = _new_correlation_id() + + prov = None + key = None + cert_ctx = None + cert_der = None # Track for return + + try: + # 1) getplatformmetadata: fetch metadata (client_id, tenant_id, cu_id, attestationEndpoint) + meta_url = base + _CSR_METADATA_PATH + meta = _imds_get_json(http_client, meta_url, params, _imds_headers(corr)) + + client_id = _get_first(meta, "clientId", "client_id") + tenant_id = _get_first(meta, "tenantId", "tenant_id") + cu_id = meta.get("cuId") if "cuId" in meta else meta.get("cu_id") + attestation_endpoint = _get_first(meta, "attestationEndpoint", "attestation_endpoint") + + if not client_id or not tenant_id or cu_id is None: + raise MsiV2Error(f"[msi_v2] getplatformmetadata missing required fields: {meta}") + + # 2) Open-or-create named per-boot KeyGuard RSA key + key_name = os.getenv(_KEY_NAME_ENVVAR) or _stable_key_name(str(client_id)) + prov, key, key_name, opened = _open_or_create_keyguard_rsa_key(win32, key_name=key_name) + logger.debug("[msi_v2] KeyGuard key name=%s opened_existing=%s", key_name, opened) + + # 3) Build PKCS#10 CSR with cuId attribute, signed by KeyGuard key (RSA-PSS/SHA256) + csr_b64 = _build_csr_b64(win32, key, str(client_id), str(tenant_id), cu_id) + + # 4) Attestation (required for KeyGuard flow in this scenario) + if not attestation_enabled: + raise MsiV2Error("[msi_v2] attestation_enabled must be True for this KeyGuard flow.") + if not attestation_endpoint: + raise MsiV2Error("[msi_v2] attestationEndpoint missing from metadata.") + + # Get attestation JWT from MAA (with stable cache key so it persists across opens/closes) + from .msi_v2_attestation import get_attestation_jwt + att_jwt = get_attestation_jwt( + attestation_endpoint=str(attestation_endpoint), + client_id=str(client_id), + key_handle=int(key.value), + cache_key=f"{key_name}", # stable per boot + ) + if not att_jwt or not str(att_jwt).strip(): + raise MsiV2Error("[msi_v2] Attestation token is missing/empty; refusing to call issuecredential.") + + # 5) issuecredential: submit CSR + attestation JWT → get X.509 cert + issue_url = base + _ISSUE_CREDENTIAL_PATH + issue_headers = _imds_headers(corr) + issue_headers["Content-Type"] = "application/json" + + cred = _imds_post_json(http_client, issue_url, params, issue_headers, {"csr": csr_b64, "attestation_token": att_jwt}) + + cert_b64 = _get_first(cred, "certificate", "Certificate") + if not cert_b64: + raise MsiV2Error(f"[msi_v2] issuecredential missing certificate: {cred}") + + try: + cert_der = base64.b64decode(cert_b64) + except Exception as exc: + raise MsiV2Error("[msi_v2] issuecredential returned invalid base64 certificate") from exc + + canonical_client_id = _get_first(cred, "client_id", "clientId") or str(client_id) + token_endpoint = _token_endpoint_from_credential(cred) + + # 6) Create CERT_CONTEXT and bind to KeyGuard private key + cert_ctx, _, _ = _create_cert_context_with_key(win32, cert_der, key, key_name) + scope = _resource_to_scope(resource) + + # 7) POST /oauth2/v2.0/token via WinHTTP/SChannel with mTLS (client cert presentation) + token_json = _acquire_token_mtls_schannel(win32, token_endpoint, cert_ctx, canonical_client_id, scope) + + # Return token + certificate for mTLS use with resource + if token_json.get("access_token") and token_json.get("expires_in"): + cert_pem = _der_to_pem(cert_der) + cert_thumbprint = get_cert_thumbprint_sha256(cert_pem) + + return { + "access_token": token_json["access_token"], + "expires_in": int(token_json["expires_in"]), + "token_type": token_json.get("token_type") or "mtls_pop", + "resource": token_json.get("resource"), + # Certificate for mTLS with resource + "cert_pem": cert_pem, + "cert_der_b64": base64.b64encode(cert_der).decode("ascii"), + "cert_thumbprint_sha256": cert_thumbprint, + } + return token_json + + finally: + # Cleanup: release handles (key persists for the boot because it is named + per-boot) + # Crypt32: free cert context (request-scoped) + try: + if cert_ctx: + crypt32.CertFreeCertificateContext(cert_ctx) + except Exception: + pass + + # NCrypt: release provider and key handles (key persists) + try: + if key: + ncrypt.NCryptFreeObject(key) + except Exception: + pass + try: + if prov: + ncrypt.NCryptFreeObject(prov) + except Exception: + pass diff --git a/msal/msi_v2_attestation.py b/msal/msi_v2_attestation.py new file mode 100644 index 00000000..e0fe2f7b --- /dev/null +++ b/msal/msi_v2_attestation.py @@ -0,0 +1,293 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +""" +Windows attestation for MSI v2 KeyGuard keys using AttestationClientLib.dll. + +This module calls into AttestationClientLib.dll to mint an attestation JWT for a KeyGuard key handle. +It also provides a small in-memory cache to reuse the attestation JWT until ~90% of its lifetime. + +Caching notes: + - Cache is process-local (in-memory). It does not persist across process restarts. + - Cache is keyed by (attestation_endpoint, client_id, cache_key/auth_token/payload). + Provide a stable cache_key (e.g., the named per-boot key name) to maximize hits. + - If the token cannot be parsed or has no exp claim, it is not cached. + +Env vars: + - ATTESTATION_CLIENTLIB_PATH: absolute path to AttestationClientLib.dll (optional) + - MSAL_MSI_V2_ATTESTATION_CACHE: "0" disables caching (default enabled) +""" + +from __future__ import annotations + +import base64 +import ctypes +import json +import logging +import os +import sys +import threading +import time +from ctypes import POINTER, Structure, c_char_p, c_int, c_void_p +from dataclasses import dataclass +from typing import Optional, Tuple + +logger = logging.getLogger(__name__) + +_NATIVE_LOG_CB = None + +# void LogFunc(void* ctx, const char* tag, int lvl, const char* func, int line, const char* msg); +_LogFunc = ctypes.CFUNCTYPE(None, c_void_p, c_char_p, c_int, c_char_p, c_int, c_char_p) + + +class AttestationLogInfo(Structure): + _fields_ = [("Log", c_void_p), ("Ctx", c_void_p)] + + +def _default_logger(ctx, tag, lvl, func, line, msg): + try: + tag_s = tag.decode("utf-8", errors="replace") if tag else "" + func_s = func.decode("utf-8", errors="replace") if func else "" + msg_s = msg.decode("utf-8", errors="replace") if msg else "" + logger.debug("[Native:%s:%s] %s:%s - %s", tag_s, lvl, func_s, line, msg_s) + except Exception: + pass + + +def _truthy_env(name: str, default: str = "1") -> bool: + val = os.getenv(name, default) + return (val or "").strip().lower() in ("1", "true", "yes", "y", "on") + + +def _maybe_add_dll_dirs(): + """Make DLL resolution more reliable (especially for packaged apps).""" + if sys.platform != "win32": + return + + add_dir = getattr(os, "add_dll_directory", None) + if not add_dir: + return + + for p in (os.path.dirname(sys.executable), os.getcwd(), os.path.dirname(__file__)): + try: + if p and os.path.isdir(p): + add_dir(p) + except Exception: + pass + + +def _load_lib(): + from .managed_identity import MsiV2Error + + if sys.platform != "win32": + raise MsiV2Error("[msi_v2_attestation] AttestationClientLib is Windows-only.") + + _maybe_add_dll_dirs() + + explicit = os.getenv("ATTESTATION_CLIENTLIB_PATH") + try: + if explicit: + return ctypes.CDLL(explicit) + return ctypes.CDLL("AttestationClientLib.dll") + except OSError as exc: + raise MsiV2Error( + "[msi_v2_attestation] Unable to load AttestationClientLib.dll. " + "Place it next to the app/exe or set ATTESTATION_CLIENTLIB_PATH." + ) from exc + + +def _b64url_decode(s: str) -> bytes: + s = (s or "").strip() + s += "=" * ((4 - len(s) % 4) % 4) + return base64.urlsafe_b64decode(s.encode("ascii")) + + +def _try_extract_exp_iat(jwt: str) -> Tuple[Optional[int], Optional[int]]: + """ + Extract exp and iat (Unix seconds) from a JWT without validating signature. + Returns (exp, iat). Either can be None. + """ + try: + parts = jwt.split(".") + if len(parts) < 2: + return None, None + payload = json.loads(_b64url_decode(parts[1]).decode("utf-8", errors="replace")) + if not isinstance(payload, dict): + return None, None + + def _to_int(v): + if isinstance(v, bool): + return None + if isinstance(v, int): + return v + if isinstance(v, float): + return int(v) + if isinstance(v, str) and v.strip().isdigit(): + return int(v.strip()) + return None + + exp = _to_int(payload.get("exp")) + iat = _to_int(payload.get("iat")) + return exp, iat + except Exception: + return None, None + + +@dataclass(frozen=True) +class _CacheKey: + attestation_endpoint: str + client_id: str + cache_key: str + auth_token: str + client_payload: str + + +@dataclass +class _CacheEntry: + jwt: str + exp: int + refresh_after: float # epoch seconds when we should refresh + + +_CACHE_LOCK = threading.Lock() +_CACHE: dict[_CacheKey, _CacheEntry] = {} + + +def _cache_lookup(key: _CacheKey) -> Optional[str]: + if not _truthy_env("MSAL_MSI_V2_ATTESTATION_CACHE", "1"): + return None + + now = time.time() + with _CACHE_LOCK: + entry = _CACHE.get(key) + if not entry: + return None + if now >= entry.refresh_after or now >= entry.exp - 5: + return None + return entry.jwt + + +def _cache_store(key: _CacheKey, jwt: str) -> None: + if not _truthy_env("MSAL_MSI_V2_ATTESTATION_CACHE", "1"): + return + + exp, iat = _try_extract_exp_iat(jwt) + if exp is None: + return + + now = int(time.time()) + issued_at = iat if iat is not None else now + lifetime = exp - issued_at + if lifetime <= 0: + return + + # Refresh at 90% of lifetime, with a small absolute guard. + refresh_after = issued_at + (0.90 * lifetime) + refresh_after = min(refresh_after, exp - 10) + + with _CACHE_LOCK: + _CACHE[key] = _CacheEntry(jwt=jwt, exp=exp, refresh_after=float(refresh_after)) + + +def get_attestation_jwt( + *, + attestation_endpoint: str, + client_id: str, + key_handle: int, + auth_token: str = "", + client_payload: str = "{}", + cache_key: Optional[str] = None, +) -> str: + """ + Returns attestation JWT string. Raises MsiV2Error on failure. + + cache_key: + - Optional stable identifier used for caching (recommended: named per-boot key name). + - If not provided, key_handle is used as part of the cache key (less cache-friendly + if the key is opened/closed between calls). + """ + from .managed_identity import MsiV2Error + + if not attestation_endpoint: + raise MsiV2Error("[msi_v2_attestation] attestation_endpoint must be non-empty") + if not client_id: + raise MsiV2Error("[msi_v2_attestation] client_id must be non-empty") + if not key_handle: + raise MsiV2Error("[msi_v2_attestation] key_handle must be non-zero") + + stable_cache_key = cache_key if cache_key is not None else f"handle:{int(key_handle)}" + ck = _CacheKey( + attestation_endpoint=str(attestation_endpoint), + client_id=str(client_id), + cache_key=str(stable_cache_key), + auth_token=str(auth_token or ""), + client_payload=str(client_payload or "{}"), + ) + + cached = _cache_lookup(ck) + if cached: + return cached + + lib = _load_lib() + + lib.InitAttestationLib.argtypes = [POINTER(AttestationLogInfo)] + lib.InitAttestationLib.restype = c_int + + lib.AttestKeyGuardImportKey.argtypes = [ + c_char_p, # endpoint + c_char_p, # authToken + c_char_p, # clientPayload + c_void_p, # keyHandle + POINTER(c_void_p), # out token (char*) + c_char_p, # clientId + ] + lib.AttestKeyGuardImportKey.restype = c_int + + lib.FreeAttestationToken.argtypes = [c_void_p] + lib.FreeAttestationToken.restype = None + + lib.UninitAttestationLib.argtypes = [] + lib.UninitAttestationLib.restype = None + + global _NATIVE_LOG_CB # pylint: disable=global-statement + _NATIVE_LOG_CB = _LogFunc(_default_logger) + + info = AttestationLogInfo() + info.Log = ctypes.cast(_NATIVE_LOG_CB, c_void_p).value + info.Ctx = c_void_p(0) + + rc = lib.InitAttestationLib(ctypes.byref(info)) + if rc != 0: + raise MsiV2Error(f"[msi_v2_attestation] InitAttestationLib failed: {rc}") + + token_ptr = c_void_p() + try: + rc = lib.AttestKeyGuardImportKey( + attestation_endpoint.encode("utf-8"), + (auth_token or "").encode("utf-8"), + (client_payload or "{}").encode("utf-8"), + c_void_p(int(key_handle)), + ctypes.byref(token_ptr), + client_id.encode("utf-8"), + ) + if rc != 0: + raise MsiV2Error(f"[msi_v2_attestation] AttestKeyGuardImportKey failed: {rc}") + if not token_ptr.value: + raise MsiV2Error("[msi_v2_attestation] Attestation token pointer is NULL") + + token = ctypes.string_at(token_ptr.value).decode("utf-8", errors="replace") + if not token or "." not in token: + raise MsiV2Error("[msi_v2_attestation] Attestation token looks malformed") + + _cache_store(ck, token) + return token + finally: + try: + if token_ptr.value: + lib.FreeAttestationToken(token_ptr) + finally: + try: + lib.UninitAttestationLib() + except Exception: + pass diff --git a/msi-v2-sample.spec b/msi-v2-sample.spec new file mode 100644 index 00000000..65ba9781 --- /dev/null +++ b/msi-v2-sample.spec @@ -0,0 +1,45 @@ +# -*- mode: python ; coding: utf-8 -*- +from PyInstaller.utils.hooks import collect_all + +datas = [] +binaries = [] +hiddenimports = ['requests'] +tmp_ret = collect_all('cryptography') +datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] + + +a = Analysis( + ['run_msi_v2_once.py'], + pathex=[], + binaries=binaries, + datas=datas, + hiddenimports=hiddenimports, + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + noarchive=False, + optimize=0, +) +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + a.binaries, + a.datas, + [], + name='msi-v2-sample', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + upx_exclude=[], + runtime_tmpdir=None, + console=True, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) diff --git a/sample/msi_v2_sample.py b/sample/msi_v2_sample.py new file mode 100644 index 00000000..b17d9978 --- /dev/null +++ b/sample/msi_v2_sample.py @@ -0,0 +1,175 @@ +""" +MSI v2 (mTLS PoP + KeyGuard Attestation) sample for MSAL Python. + +This sample requests an *attested*, certificate-bound access token (token_type=mtls_pop) +using the IMDSv2 /issuecredential endpoint and ESTS mTLS token endpoint. + +Key points (based on our E2E debugging): +- Use a resource that supports certificate-bound tokens. In this environment, Graph mTLS test + resource is supported; ARM typically is NOT (AADSTS392196). +- Run in strict mode: if mtls_pop + attestation is requested, we fail if we receive Bearer. +- Designed for Windows Azure VM where Credential Guard (VBS/KeyGuard) is available and + AttestationClientLib.dll is present. + +Environment variables: +- RESOURCE: defaults to https://mtlstb.graph.microsoft.com +- ENDPOINT: optional URL to call after acquiring token (e.g., Graph mTLS test endpoint) +- VERBOSE_LOGGING: "1"/"true" enables debug logs +- ATTESTATION_CLIENTLIB_PATH: optional absolute path to AttestationClientLib.dll (recommended) +- PYTHONNET_RUNTIME: must be "coreclr" for CSR OtherRequestAttributes (if using pythonnet path) + +Usage: + set RESOURCE=https://mtlstb.graph.microsoft.com + set ENDPOINT=https://mtlstb.graph.microsoft.com/v1.0/applications?$top=1 + python msi_v2_sample.py +""" + +import json +import logging +import os +import sys +import time + +import msal +import requests + + +# ------------------------- Logging ------------------------- + +def _truthy(s: str) -> bool: + return (s or "").strip().lower() in ("1", "true", "yes", "y", "on") + +if _truthy(os.getenv("VERBOSE_LOGGING", "")): + logging.basicConfig(level=logging.DEBUG) + logging.getLogger("msal").setLevel(logging.DEBUG) + +log = logging.getLogger("msi_v2_sample") + + +# ------------------------- Defaults ------------------------- + +# Graph mTLS test resource (known-good for mtls_pop in your environment) +DEFAULT_RESOURCE = "https://mtlstb.graph.microsoft.com" + +# ARM will often fail for mtls_pop with AADSTS392196 +ARM_RESOURCE = "https://management.azure.com/" + +RESOURCE = os.getenv("RESOURCE", DEFAULT_RESOURCE).strip().rstrip("/") +ENDPOINT = os.getenv("ENDPOINT", "").strip() + +# Token cache is optional; keep it simple for E2E +token_cache = msal.TokenCache() + +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), + token_cache=token_cache, +) + + +# ------------------------- Helpers ------------------------- + +def _print_env_hints(): + if RESOURCE.lower().startswith(ARM_RESOURCE): + print("NOTE: RESOURCE is ARM. mtls_pop usually fails for ARM with AADSTS392196.") + print(f" Try: set RESOURCE={DEFAULT_RESOURCE}") + + if sys.platform != "win32": + print("NOTE: This sample is designed for Windows KeyGuard + attestation.") + + +def _call_endpoint_bearer(endpoint: str, token_type: str, access_token: str): + """ + Simple HTTP call using Authorization header. + NOTE: If the *resource* requires client cert at the resource layer too, this may not work. + For your current E2E, token acquisition is the primary goal. + """ + headers = {"Authorization": f"{token_type} {access_token}", "Accept": "application/json"} + r = requests.get(endpoint, headers=headers, timeout=30) + try: + return r.status_code, r.headers, r.json() + except Exception: + return r.status_code, r.headers, r.text + + +# ------------------------- Main flow ------------------------- + +def acquire_mtls_pop_token_strict(): + """ + Acquire MSI v2 token in STRICT mode: + - We request mtls_proof_of_possession=True and with_attestation_support=True + - If we don't get token_type=mtls_pop, treat as failure + """ + result = client.acquire_token_for_client( + resource=RESOURCE, + mtls_proof_of_possession=True, # MSI v2 path + with_attestation_support=True, # KeyGuard attestation required for your scenario + ) + + if "access_token" not in result: + raise RuntimeError(f"Token acquisition failed: {json.dumps(result, indent=2)}") + + token_type = (result.get("token_type") or "Bearer").lower() + if token_type != "mtls_pop": + # In strict mode, bearer is a failure + raise RuntimeError( + "Strict MSI v2 requested, but got non-mtls_pop token.\n" + f"token_type={result.get('token_type')}\n" + "This usually means MSI v2 failed or you requested a resource that doesn't support " + "certificate-bound tokens.\n" + f"Try RESOURCE={DEFAULT_RESOURCE}\n" + f"Full result: {json.dumps(result, indent=2)}" + ) + + return result + + +def main_once(): + _print_env_hints() + + # For pythonnet-based CSR attribute support, coreclr is required. + # If you're running via pythonnet and hit OtherRequestAttributes issues, set: + # set PYTHONNET_RUNTIME=coreclr + if os.getenv("PYTHONNET_RUNTIME"): + log.debug("PYTHONNET_RUNTIME=%s", os.getenv("PYTHONNET_RUNTIME")) + + print("Requesting MSI v2 token (mtls_pop + attestation)...") + result = acquire_mtls_pop_token_strict() + + print("SUCCESS: token acquired") + print(" resource =", RESOURCE) + print(" token_len =", len(result["access_token"])) + + if ENDPOINT: + print("\nCalling ENDPOINT (best-effort using Authorization header):") + status, headers, body = _call_endpoint_bearer( + ENDPOINT, result.get("token_type", "mtls_pop"), result["access_token"] + ) + print(" status =", status) + # Print a small response preview + if isinstance(body, (dict, list)): + print(json.dumps(body, indent=2)[:2000]) + else: + print(str(body)[:2000]) + + +if __name__ == "__main__": + # Run once by default (simpler for debugging) + # Set LOOP=1 if you want repeated calls + loop = _truthy(os.getenv("LOOP", "")) + if not loop: + try: + main_once() + sys.exit(0) + except Exception as ex: + print("FAIL:", ex) + sys.exit(2) + + # Optional loop mode + while True: + try: + main_once() + except Exception as ex: + print("FAIL:", ex) + print("Sleeping 10 seconds... (Ctrl-C to stop)") + time.sleep(10) diff --git a/tests/test_msi_v2.py b/tests/test_msi_v2.py new file mode 100644 index 00000000..6785399f --- /dev/null +++ b/tests/test_msi_v2.py @@ -0,0 +1,321 @@ +# Copyright (c) Microsoft Corporation. +# All rights reserved. +# +# This code is licensed under the MIT License. +"""Tests for MSI v2 (mTLS PoP) implementation. + +Goals: +- Provide strong unit coverage without depending on pythonnet / KeyGuard / real IMDS. +- Avoid importing optional helpers that may not exist in the KeyGuard implementation. +- Validate: + * x5t#S256 helper correctness (local) + * verify_cnf_binding behavior (msal.msi_v2) + * ManagedIdentityClient strict gating behavior (msi v2 invoked only when explicitly requested) + * Optional IMDSv2 wire-contract helpers when present (skipped if not exposed) +""" + +import base64 +import datetime +import hashlib +import json +import os +import unittest + +try: + from unittest.mock import patch, MagicMock +except ImportError: + from mock import patch, MagicMock + +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + +import msal +from msal import MsiV2Error + + +# Import only stable surface from msal.msi_v2 +from msal.msi_v2 import verify_cnf_binding + +# MinimalResponse is used in other test modules; safe to reuse here +from tests.test_throttled_http_client import MinimalResponse + + +# --------------------------------------------------------------------------- +# Local helpers (do not rely on msal.msi_v2 exporting these) +# --------------------------------------------------------------------------- + +def _make_self_signed_cert(private_key, common_name="test"): + subject = issuer = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)]) + now = datetime.datetime.now(datetime.timezone.utc) + cert = ( + x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(now) + .not_valid_after(now + datetime.timedelta(days=1)) + .sign(private_key, hashes.SHA256(), default_backend()) + ) + return cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") + + +def get_cert_thumbprint_sha256(cert_pem: str) -> str: + """x5t#S256 = base64url(SHA256(der(cert))) without padding.""" + cert = x509.load_pem_x509_certificate(cert_pem.encode("utf-8"), default_backend()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + return base64.urlsafe_b64encode(hashlib.sha256(cert_der).digest()).rstrip(b"=").decode("ascii") + + +def _b64url(s: bytes) -> str: + return base64.urlsafe_b64encode(s).rstrip(b"=").decode("ascii") + + +def _make_jwt(payload_obj, header_obj=None) -> str: + header_obj = header_obj or {"alg": "RS256", "typ": "JWT"} + header = _b64url(json.dumps(header_obj, separators=(",", ":")).encode("utf-8")) + payload = _b64url(json.dumps(payload_obj, separators=(",", ":")).encode("utf-8")) + sig = _b64url(b"sig") + return f"{header}.{payload}.{sig}" + + +# --------------------------------------------------------------------------- +# Thumbprint helper +# --------------------------------------------------------------------------- + +class TestThumbprintHelper(unittest.TestCase): + def setUp(self): + self.key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + self.cert_pem = _make_self_signed_cert(self.key, "thumbprint-test") + + def test_returns_base64url_no_padding(self): + thumb = get_cert_thumbprint_sha256(self.cert_pem) + self.assertIsInstance(thumb, str) + self.assertNotIn("=", thumb) + + decoded = base64.urlsafe_b64decode(thumb + "==") + self.assertEqual(len(decoded), 32) + + def test_same_cert_same_thumbprint(self): + t1 = get_cert_thumbprint_sha256(self.cert_pem) + t2 = get_cert_thumbprint_sha256(self.cert_pem) + self.assertEqual(t1, t2) + + def test_different_certs_different_thumbprints(self): + key2 = rsa.generate_private_key(public_exponent=65537, key_size=2048) + cert2_pem = _make_self_signed_cert(key2, "thumbprint-test-2") + self.assertNotEqual(get_cert_thumbprint_sha256(self.cert_pem), + get_cert_thumbprint_sha256(cert2_pem)) + + def test_matches_manual_sha256_der(self): + cert = x509.load_pem_x509_certificate(self.cert_pem.encode("utf-8"), default_backend()) + cert_der = cert.public_bytes(serialization.Encoding.DER) + expected = base64.urlsafe_b64encode(hashlib.sha256(cert_der).digest()).rstrip(b"=").decode("ascii") + self.assertEqual(get_cert_thumbprint_sha256(self.cert_pem), expected) + + +# --------------------------------------------------------------------------- +# verify_cnf_binding (more coverage) +# --------------------------------------------------------------------------- + +class TestVerifyCnfBinding(unittest.TestCase): + def setUp(self): + self.key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + self.cert_pem = _make_self_signed_cert(self.key, "cnf-test") + self.thumbprint = get_cert_thumbprint_sha256(self.cert_pem) + + def test_valid_binding_true(self): + token = _make_jwt({"cnf": {"x5t#S256": self.thumbprint}}) + self.assertTrue(verify_cnf_binding(token, self.cert_pem)) + + def test_wrong_thumbprint_false(self): + token = _make_jwt({"cnf": {"x5t#S256": "wrong"}}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_missing_cnf_false(self): + token = _make_jwt({"sub": "nobody"}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_missing_x5t_false(self): + token = _make_jwt({"cnf": {}}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_cnf_not_object_false(self): + token = _make_jwt({"cnf": "not-an-object"}) + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_not_a_jwt_false(self): + self.assertFalse(verify_cnf_binding("notajwt", self.cert_pem)) + + def test_two_part_jwt_false(self): + token = "a.b" + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_four_part_jwt_false(self): + token = "a.b.c.d" + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_malformed_payload_base64_false(self): + token = "header.!!!.sig" + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_payload_not_json_false(self): + header = _b64url(b'{"alg":"none"}') + payload = _b64url(b"not-json") + token = f"{header}.{payload}.sig" + self.assertFalse(verify_cnf_binding(token, self.cert_pem)) + + def test_payload_with_padding_still_works(self): + # Create payload base64 with explicit padding (library should tolerate) + header = base64.urlsafe_b64encode(b'{"alg":"RS256"}').decode("ascii") # includes padding sometimes + payload = base64.urlsafe_b64encode(json.dumps({"cnf": {"x5t#S256": self.thumbprint}}).encode("utf-8")).decode("ascii") + token = f"{header}.{payload}.sig" + self.assertTrue(verify_cnf_binding(token, self.cert_pem)) + + def test_unicode_in_payload_does_not_break(self): + token = _make_jwt({"cnf": {"x5t#S256": self.thumbprint}, "msg": "こんにちは"}) + self.assertTrue(verify_cnf_binding(token, self.cert_pem)) + + +# --------------------------------------------------------------------------- +# ManagedIdentityClient gating + strict behavior (better coverage) +# --------------------------------------------------------------------------- + +class TestManagedIdentityClientStrictGating(unittest.TestCase): + def _make_client(self): + import requests + return msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=requests.Session(), + ) + + def test_error_is_exported(self): + self.assertIs(msal.MsiV2Error, MsiV2Error) + + def test_error_is_subclass(self): + self.assertTrue(issubclass(MsiV2Error, msal.ManagedIdentityError)) + + @patch("msal.managed_identity._obtain_token") + def test_default_path_calls_v1(self, mock_v1): + mock_v1.return_value = {"access_token": "V1", "expires_in": 3600, "token_type": "Bearer"} + client = self._make_client() + res = client.acquire_token_for_client(resource="R") + self.assertEqual(res["access_token"], "V1") + mock_v1.assert_called_once() + + def test_attestation_requires_pop(self): + client = self._make_client() + with self.assertRaises(msal.ManagedIdentityError): + client.acquire_token_for_client(resource="R", + mtls_proof_of_possession=False, + with_attestation_support=True) + + @patch("msal.msi_v2.obtain_token") + @patch("msal.managed_identity._obtain_token") + def test_pop_without_attestation_does_not_call_v2(self, mock_v1, mock_v2): + # If your implementation requires BOTH flags, v2 must not run here. + mock_v1.return_value = {"access_token": "V1", "expires_in": 3600, "token_type": "Bearer"} + client = self._make_client() + res = client.acquire_token_for_client(resource="R", + mtls_proof_of_possession=True, + with_attestation_support=False) + # depending on your design this could either raise or fall back to v1. + # If you changed to "v2 only when both flags", it should use v1. + self.assertEqual(res["token_type"], "Bearer") + mock_v2.assert_not_called() + mock_v1.assert_called_once() + + @patch("msal.msi_v2.obtain_token") + def test_v2_called_when_both_flags_true(self, mock_v2): + mock_v2.return_value = {"access_token": "V2", "expires_in": 3600, "token_type": "mtls_pop"} + client = self._make_client() + res = client.acquire_token_for_client(resource="https://mtlstb.graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True) + self.assertEqual(res["token_type"], "mtls_pop") + mock_v2.assert_called_once() + # Ensure v2 called with expected signature (resource argument passed through) + args, kwargs = mock_v2.call_args + # obtain_token(http_client, managed_identity, resource, attestation_enabled=...) + self.assertTrue(len(args) >= 3) + self.assertEqual(args[2], "https://mtlstb.graph.microsoft.com") + self.assertIn("attestation_enabled", kwargs) + self.assertTrue(kwargs["attestation_enabled"]) + + @patch("msal.msi_v2.obtain_token", side_effect=MsiV2Error("boom")) + @patch("msal.managed_identity._obtain_token") + def test_strict_v2_failure_raises_no_v1_fallback(self, mock_v1, mock_v2): + client = self._make_client() + with self.assertRaises(MsiV2Error): + client.acquire_token_for_client(resource="https://mtlstb.graph.microsoft.com", + mtls_proof_of_possession=True, + with_attestation_support=True) + mock_v1.assert_not_called() + + +# --------------------------------------------------------------------------- +# Optional: wire contract helper tests (skip if helpers not present) +# --------------------------------------------------------------------------- + +class TestImdsV2OptionalHelpers(unittest.TestCase): + def test_mi_query_params_adds_version_and_uami_selector(self): + if not hasattr(msal.msi_v2, "_mi_query_params"): + self.skipTest("msal.msi_v2._mi_query_params not exposed") + + p = msal.msi_v2._mi_query_params({"ManagedIdentityIdType": "ClientId", "Id": "abc"}) + self.assertIn("cred-api-version", p) + self.assertEqual(p["cred-api-version"], "2.0") + self.assertEqual(p.get("client_id"), "abc") + + p2 = msal.msi_v2._mi_query_params({"ManagedIdentityIdType": "ObjectId", "Id": "oid"}) + self.assertEqual(p2.get("object_id"), "oid") + + p3 = msal.msi_v2._mi_query_params({"ManagedIdentityIdType": "ResourceId", "Id": "/sub/..."}) + self.assertEqual(p3.get("msi_res_id"), "/sub/...") + + def test_issuecredential_body_uses_attestation_token(self): + if not hasattr(msal.msi_v2, "_imds_post_json"): + self.skipTest("msal.msi_v2._imds_post_json not exposed") + if not hasattr(msal.msi_v2, "_issue_credential"): + self.skipTest("msal.msi_v2._issue_credential not exposed") + + http_client = MagicMock() + http_client.post.return_value = MinimalResponse( + status_code=200, + text=json.dumps({ + "certificate": "Zg==", + "mtls_authentication_endpoint": "https://login", + "tenant_id": "t", + "client_id": "c", + }), + ) + + msal.msi_v2._issue_credential( + http_client, + managed_identity={"ManagedIdentityIdType": "SystemAssigned", "Id": None}, + csr_b64="QUJD", + attestation_jwt="fake.jwt", + ) + + _, kwargs = http_client.post.call_args + body = json.loads(kwargs["data"]) + self.assertEqual(body["csr"], "QUJD") + self.assertEqual(body["attestation_token"], "fake.jwt") + + def test_token_endpoint_derived_from_mtls_auth_endpoint(self): + if not hasattr(msal.msi_v2, "_token_endpoint_from_credential"): + self.skipTest("msal.msi_v2._token_endpoint_from_credential not exposed") + + cred = { + "mtls_authentication_endpoint": "https://login.example.com", + "tenant_id": "tenant123", + } + ep = msal.msi_v2._token_endpoint_from_credential(cred) + self.assertTrue(ep.endswith("/tenant123/oauth2/v2.0/token")) + + +if __name__ == "__main__": + unittest.main() diff --git a/working/run_msi_v2_once.py b/working/run_msi_v2_once.py new file mode 100644 index 00000000..a34558c6 --- /dev/null +++ b/working/run_msi_v2_once.py @@ -0,0 +1,56 @@ +import json +import os +import sys + +import msal +import requests + +DEFAULT_RESOURCE = "https://graph.microsoft.com" +RESOURCE = os.getenv("RESOURCE", DEFAULT_RESOURCE).strip().rstrip("/") + +session = requests.Session() +cache = msal.TokenCache() + +client = msal.ManagedIdentityClient( + msal.SystemAssignedManagedIdentity(), + http_client=session, + token_cache=cache, +) + +def acquire_mtls_pop_token_strict(): + result = client.acquire_token_for_client( + resource=RESOURCE, + mtls_proof_of_possession=True, + with_attestation_support=True, + ) + + if "access_token" not in result: + raise RuntimeError(f"Token acquisition failed: {json.dumps(result, indent=2)}") + + token_type = (result.get("token_type") or "Bearer").lower() + if token_type != "mtls_pop": + raise RuntimeError( + f"Strict MSI v2 requested, but got token_type={result.get('token_type')}. " + f"Full result: {json.dumps(result, indent=2)}" + ) + + return result + +if __name__ == "__main__": + try: + r1 = acquire_mtls_pop_token_strict() + print("token received (1)") + + r2 = acquire_mtls_pop_token_strict() + print("token received (2)") + + # If MSAL exposes a cache indicator, avoid printing its concrete value to logs + ts1 = r1.get("token_source") or r1.get("source") or "" + ts2 = r2.get("token_source") or r2.get("source") or "" + if ts1 or ts2: + print("token sources are available (not logged for security)") + + sys.exit(0) + except Exception as ex: + print("FAIL:", ex) + sys.exit(2) \ No newline at end of file